2020package org .apache .sysds .runtime .instructions .ooc ;
2121
2222import java .util .HashMap ;
23- import java .util .Map ;
2423import java .util .concurrent .ExecutorService ;
2524
2625import org .apache .sysds .common .Opcodes ;
@@ -82,6 +81,9 @@ public void processInstruction( ExecutionContext ec ) {
8281 partitionedVector .put (key , vectorSlice );
8382 }
8483
84+ // number of column blocks of the input; a row's aggregate is emitted early once this many partial results were combined
85+ long nBlocks = min .getDataCharacteristics ().getNumColBlocks ();
86+
8587 LocalTaskQueue <IndexedMatrixValue > qIn = min .getStreamHandle ();
8688 LocalTaskQueue <IndexedMatrixValue > qOut = new LocalTaskQueue <>();
8789 BinaryOperator plus = InstructionUtils .parseBinaryOperator (Opcodes .PLUS .toString ());
@@ -94,6 +96,7 @@ public void processInstruction( ExecutionContext ec ) {
9496 IndexedMatrixValue tmp = null ;
9597 try {
9698 HashMap <Long , MatrixBlock > partialResults = new HashMap <>();
99+ HashMap <Long , Integer > cnt = new HashMap <>();
97100 while ((tmp = qIn .dequeueTask ()) != LocalTaskQueue .NO_MORE_TASKS ) {
98101 MatrixBlock matrixBlock = (MatrixBlock ) tmp .getValue ();
99102 long rowIndex = tmp .getIndexes ().getRowIndex ();
@@ -109,19 +112,26 @@ public void processInstruction( ExecutionContext ec ) {
109112 qOut .enqueueTask (new IndexedMatrixValue (tmp .getIndexes (), partialResult ));
110113 }
111114 else {
115+ // aggregate partial results per row index and count how many were combined for each row
112116 MatrixBlock currAgg = partialResults .get (rowIndex );
113- if (currAgg == null )
117+ if (currAgg == null ) {
114118 partialResults .put (rowIndex , partialResult );
115- else
119+ cnt .put (rowIndex , 1 );
120+ }
121+ else {
116122 currAgg .binaryOperationsInPlace (plus , partialResult );
117- }
118- }
119-
120- // emit aggregated blocks
121- if ( min .getNumColumns () > min .getBlocksize () ) {
122- for (Map .Entry <Long , MatrixBlock > entry : partialResults .entrySet ()) {
123- MatrixIndexes outIndexes = new MatrixIndexes (entry .getKey (), 1L );
124- qOut .enqueueTask (new IndexedMatrixValue (outIndexes , entry .getValue ()));
123+ int newCnt = cnt .get (rowIndex ) + 1 ;
124+ cnt .replace (rowIndex , newCnt );
125+
126+ if (newCnt == nBlocks ){
127+ // early block output: count reached the number of column blocks, so emit the aggregated block and drop its state
128+ MatrixIndexes idx = new MatrixIndexes (rowIndex , 1L );
129+ MatrixBlock result = partialResults .get (rowIndex );
130+ qOut .enqueueTask (new IndexedMatrixValue (idx , result ));
131+ partialResults .remove (rowIndex );
132+ cnt .remove (rowIndex );
133+ }
134+ }
125135 }
126136 }
127137 }
0 commit comments