|
38 | 38 | #include <vector> |
39 | 39 | #include <set> |
40 | 40 |
|
41 | | - |
42 | 41 | #include <htm/algorithms/TemporalMemory.hpp> |
43 | | - |
44 | | -#include <htm/utils/GroupBy.hpp> |
45 | 42 | #include <htm/algorithms/Anomaly.hpp> |
46 | 43 |
|
47 | 44 | using namespace std; |
@@ -279,61 +276,69 @@ void TemporalMemory::activateCells(const SDR &activeColumns, const bool learn) { |
279 | 276 |
|
280 | 277 | const vector<CellIdx> prevWinnerCells = std::move(winnerCells_); |
281 | 278 |
|
282 | | - //maps segment S to a new segment that is at start of a column where |
283 | | - //S belongs. |
284 | | - //for 3 cells per columns: |
285 | | - //s1_1, s1_2, s1_3, s2_1, s2_2, s2_3, ... |
286 | | - //columnForSegment (for short here CFS) |
287 | | - //CFS(s1_1) = s1_1 = "start of column 1" |
288 | | - //CFS(s1_2) = s1_1 |
289 | | - //CFS(s1_3) = s1_1 |
290 | | - //CFS(s2_1) = s2_1 = "column 2" |
291 | | - //CFS(s2_2) = s2_1 |
292 | | - //... |
293 | | - const auto toColumns = [&](const Segment segment) { |
| 279 | + const auto getColumnOfSegment = [&](const Segment segment) { |
294 | 280 | return connections.cellForSegment(segment) / cellsPerColumn_; |
295 | 281 | }; |
296 | | - const auto identity = [](const ElemSparse a) {return a;}; //TODO use std::identity when c++20 |
297 | | - |
298 | | - for (auto &&columnData : groupBy( //group by columns, and convert activeSegments & matchingSegments to cols. |
299 | | - sparse, identity, |
300 | | - activeSegments_, toColumns, |
301 | | - matchingSegments_, toColumns)) { |
302 | | - |
303 | | - Segment column; //we say "column", but it's the first segment of n-segments/cells that belong to the column |
304 | | - vector<Segment>::const_iterator activeColumnsBegin, activeColumnsEnd, |
305 | | - columnActiveSegmentsBegin, columnActiveSegmentsEnd, |
306 | | - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd; |
307 | | - |
308 | | - // for column in activeColumns (the 'sparse' above): |
309 | | - // get its active segments ( >= connectedThr) |
310 | | - // get its matching segs ( >= TODO |
311 | | - std::tie(column, |
312 | | - activeColumnsBegin, activeColumnsEnd, |
313 | | - columnActiveSegmentsBegin, columnActiveSegmentsEnd, |
314 | | - columnMatchingSegmentsBegin, columnMatchingSegmentsEnd |
315 | | - ) = columnData; |
316 | | - |
317 | | - const bool isActiveColumn = activeColumnsBegin != activeColumnsEnd; |
318 | | - if (isActiveColumn) { //current active column... |
| 282 | + |
| 283 | + // Iterate over these three lists at the same time. |
| 284 | + auto activeColumnsBegin = sparse.cbegin(); |
| 285 | + auto columnActiveSegmentsBegin = activeSegments_.cbegin(); |
| 286 | + auto columnMatchingSegmentsBegin = matchingSegments_.cbegin(); |
| 287 | + while(true) { |
| 288 | + // Find the next (lowest indexed) column in any of the three lists. |
| 289 | + Segment column = numColumns_; // Sentinel value, all column indexes are less than this value. |
| 290 | + if (activeColumnsBegin != sparse.cend()) { |
| 291 | + column = std::min(column, *activeColumnsBegin); |
| 292 | + } |
| 293 | + if (columnActiveSegmentsBegin != activeSegments_.cend()) { |
| 294 | + column = std::min(column, getColumnOfSegment(*columnActiveSegmentsBegin)); |
| 295 | + } |
| 296 | + if (columnMatchingSegmentsBegin != matchingSegments_.cend()) { |
| 297 | + column = std::min(column, getColumnOfSegment(*columnMatchingSegmentsBegin)); |
| 298 | + } |
| 299 | + if (column == numColumns_) { |
| 300 | + break; |
| 301 | + } |
| 302 | + // Find all contiguous stretches of the lists which are part of the selected column. |
| 303 | + auto activeColumnsEnd = activeColumnsBegin; |
| 304 | + while (activeColumnsEnd != sparse.cend() |
| 305 | + && *activeColumnsEnd == column) { |
| 306 | + ++activeColumnsEnd; |
| 307 | + } |
| 308 | + auto columnActiveSegmentsEnd = columnActiveSegmentsBegin; |
| 309 | + while (columnActiveSegmentsEnd != activeSegments_.cend() && |
| 310 | + getColumnOfSegment(*columnActiveSegmentsEnd) == column) { |
| 311 | + ++columnActiveSegmentsEnd; |
| 312 | + } |
| 313 | + auto columnMatchingSegmentsEnd = columnMatchingSegmentsBegin; |
| 314 | + while (columnMatchingSegmentsEnd != matchingSegments_.cend() |
| 315 | + && getColumnOfSegment(*columnMatchingSegmentsEnd) == column) { |
| 316 | + ++columnMatchingSegmentsEnd; |
| 317 | + } |
| 318 | + |
| 319 | + if (activeColumnsBegin != activeColumnsEnd) { |
| 320 | + // This column is active. |
319 | 321 | if (columnActiveSegmentsBegin != columnActiveSegmentsEnd) { |
320 | | - //...was also predicted -> learn :o) |
| 322 | + // This column was also predicted. |
321 | 323 | activatePredictedColumn_( |
322 | 324 | columnActiveSegmentsBegin, columnActiveSegmentsEnd, |
323 | 325 | prevActiveCells, prevWinnerCells, learn); |
324 | 326 | } else { |
325 | | - //...has not been predicted -> |
| 327 | + // This column was not predicted. |
326 | 328 | burstColumn_(column, |
327 | 329 | columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, |
328 | 330 | prevActiveCells, prevWinnerCells, |
329 | | - learn); |
| 331 | + learn); |
330 | 332 | } |
331 | | - |
332 | | - } else { // predicted but not active column -> unlearn |
| 333 | + } else { |
| 334 | + // This column was predicted but is not active. |
333 | 335 | if (learn) { |
334 | 336 | punishPredictedColumn_(columnMatchingSegmentsBegin, columnMatchingSegmentsEnd, prevActiveCells); |
335 | 337 | } |
336 | | - } //else: not predicted & not active -> no activity -> does not show up at all |
| 338 | + } // else: not predicted & not active -> do nothing to the column. |
| 339 | + activeColumnsBegin = activeColumnsEnd; |
| 340 | + columnActiveSegmentsBegin = columnActiveSegmentsEnd; |
| 341 | + columnMatchingSegmentsBegin = columnMatchingSegmentsEnd; |
337 | 342 | } |
338 | 343 | segmentsValid_ = false; |
339 | 344 | } |
|
0 commit comments