Skip to content

Commit f19bd21

Browse files
authored
Merge pull request #974 from htm-community/replace-groupby
Replace htm/utils/GroupBy with simpler in-lined implementation.
2 parents 2444639 + 1002426 commit f19bd21

File tree

5 files changed

+49
-372
lines changed

5 files changed

+49
-372
lines changed

src/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ set(types_files
178178
)
179179

180180
set(utils_files
181-
htm/utils/GroupBy.hpp
182181
htm/utils/Log.hpp
183182
htm/utils/MovingAverage.cpp
184183
htm/utils/MovingAverage.hpp

src/htm/algorithms/TemporalMemory.cpp

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@
3838
#include <vector>
3939
#include <set>
4040

41-
4241
#include <htm/algorithms/TemporalMemory.hpp>
43-
44-
#include <htm/utils/GroupBy.hpp>
4542
#include <htm/algorithms/Anomaly.hpp>
4643

4744
using namespace std;
@@ -279,61 +276,69 @@ void TemporalMemory::activateCells(const SDR &activeColumns, const bool learn) {
279276

280277
const vector<CellIdx> prevWinnerCells = std::move(winnerCells_);
281278

282-
//maps segment S to a new segment that is at start of a column where
283-
//S belongs.
284-
//for 3 cells per columns:
285-
//s1_1, s1_2, s1_3, s2_1, s2_2, s2_3, ...
286-
//columnForSegment (for short here CFS)
287-
//CFS(s1_1) = s1_1 = "start of column 1"
288-
//CFS(s1_2) = s1_1
289-
//CFS(s1_3) = s1_1
290-
//CFS(s2_1) = s2_1 = "column 2"
291-
//CFS(s2_2) = s2_1
292-
//...
293-
const auto toColumns = [&](const Segment segment) {
279+
const auto getColumnOfSegment = [&](const Segment segment) {
294280
return connections.cellForSegment(segment) / cellsPerColumn_;
295281
};
296-
const auto identity = [](const ElemSparse a) {return a;}; //TODO use std::identity when c++20
297-
298-
for (auto &&columnData : groupBy( //group by columns, and convert activeSegments & matchingSegments to cols.
299-
sparse, identity,
300-
activeSegments_, toColumns,
301-
matchingSegments_, toColumns)) {
302-
303-
Segment column; //we say "column", but it's the first segment of n-segments/cells that belong to the column
304-
vector<Segment>::const_iterator activeColumnsBegin, activeColumnsEnd,
305-
columnActiveSegmentsBegin, columnActiveSegmentsEnd,
306-
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd;
307-
308-
// for column in activeColumns (the 'sparse' above):
309-
// get its active segments ( >= connectedThr)
310-
// get its matching segs ( >= TODO
311-
std::tie(column,
312-
activeColumnsBegin, activeColumnsEnd,
313-
columnActiveSegmentsBegin, columnActiveSegmentsEnd,
314-
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd
315-
) = columnData;
316-
317-
const bool isActiveColumn = activeColumnsBegin != activeColumnsEnd;
318-
if (isActiveColumn) { //current active column...
282+
283+
// Iterate over these three lists at the same time.
284+
auto activeColumnsBegin = sparse.cbegin();
285+
auto columnActiveSegmentsBegin = activeSegments_.cbegin();
286+
auto columnMatchingSegmentsBegin = matchingSegments_.cbegin();
287+
while(true) {
288+
// Find the next (lowest indexed) column in any of the three lists.
289+
Segment column = numColumns_; // Sentinel value, all column indexes are less than this value.
290+
if (activeColumnsBegin != sparse.cend()) {
291+
column = std::min(column, *activeColumnsBegin);
292+
}
293+
if (columnActiveSegmentsBegin != activeSegments_.cend()) {
294+
column = std::min(column, getColumnOfSegment(*columnActiveSegmentsBegin));
295+
}
296+
if (columnMatchingSegmentsBegin != matchingSegments_.cend()) {
297+
column = std::min(column, getColumnOfSegment(*columnMatchingSegmentsBegin));
298+
}
299+
if (column == numColumns_) {
300+
break;
301+
}
302+
// Find all contiguous stretches of the lists which are part of the selected column.
303+
auto activeColumnsEnd = activeColumnsBegin;
304+
while (activeColumnsEnd != sparse.cend()
305+
&& *activeColumnsEnd == column) {
306+
++activeColumnsEnd;
307+
}
308+
auto columnActiveSegmentsEnd = columnActiveSegmentsBegin;
309+
while (columnActiveSegmentsEnd != activeSegments_.cend() &&
310+
getColumnOfSegment(*columnActiveSegmentsEnd) == column) {
311+
++columnActiveSegmentsEnd;
312+
}
313+
auto columnMatchingSegmentsEnd = columnMatchingSegmentsBegin;
314+
while (columnMatchingSegmentsEnd != matchingSegments_.cend()
315+
&& getColumnOfSegment(*columnMatchingSegmentsEnd) == column) {
316+
++columnMatchingSegmentsEnd;
317+
}
318+
319+
if (activeColumnsBegin != activeColumnsEnd) {
320+
// This column is active.
319321
if (columnActiveSegmentsBegin != columnActiveSegmentsEnd) {
320-
//...was also predicted -> learn :o)
322+
// This column was also predicted.
321323
activatePredictedColumn_(
322324
columnActiveSegmentsBegin, columnActiveSegmentsEnd,
323325
prevActiveCells, prevWinnerCells, learn);
324326
} else {
325-
//...has not been predicted ->
327+
// This column was not predicted.
326328
burstColumn_(column,
327329
columnMatchingSegmentsBegin, columnMatchingSegmentsEnd,
328330
prevActiveCells, prevWinnerCells,
329-
learn);
331+
learn);
330332
}
331-
332-
} else { // predicted but not active column -> unlearn
333+
} else {
334+
// This column was predicted but is not active.
333335
if (learn) {
334336
punishPredictedColumn_(columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, prevActiveCells);
335337
}
336-
} //else: not predicted & not active -> no activity -> does not show up at all
338+
} // else: not predicted & not active -> do nothing to the column.
339+
activeColumnsBegin = activeColumnsEnd;
340+
columnActiveSegmentsBegin = columnActiveSegmentsEnd;
341+
columnMatchingSegmentsBegin = columnMatchingSegmentsEnd;
337342
}
338343
segmentsValid_ = false;
339344
}

src/htm/utils/GroupBy.hpp

Lines changed: 0 additions & 243 deletions
This file was deleted.

src/test/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ set(types_tests
9797
)
9898

9999
set(utils_tests
100-
unit/utils/GroupByTest.cpp
101100
unit/utils/MovingAverageTest.cpp
102101
unit/utils/RandomTest.cpp
103102
unit/utils/VectorHelpersTest.cpp

0 commit comments

Comments
 (0)