2020package org .apache .sysds .runtime .instructions .ooc ;
2121
2222import java .util .HashMap ;
23- import java .util .Map ;
2423import java .util .concurrent .ExecutorService ;
2524
2625import org .apache .sysds .common .Opcodes ;
@@ -82,6 +81,9 @@ public void processInstruction( ExecutionContext ec ) {
8281 partitionedVector .put (key , vectorSlice );
8382 }
8483
84+ // number of column blocks per block row; once this many partial results have been aggregated for a row index, its block is emitted early
85+ long nBlocks = min .getDataCharacteristics ().getNumColBlocks ();
86+
8587 LocalTaskQueue <IndexedMatrixValue > qIn = min .getStreamHandle ();
8688 LocalTaskQueue <IndexedMatrixValue > qOut = new LocalTaskQueue <>();
8789 BinaryOperator plus = InstructionUtils .parseBinaryOperator (Opcodes .PLUS .toString ());
@@ -94,6 +96,7 @@ public void processInstruction( ExecutionContext ec ) {
9496 IndexedMatrixValue tmp = null ;
9597 try {
9698 HashMap <Long , MatrixBlock > partialResults = new HashMap <>();
99+ HashMap <Long , Integer > cnt = new HashMap <>();
97100 while ((tmp = qIn .dequeueTask ()) != LocalTaskQueue .NO_MORE_TASKS ) {
98101 MatrixBlock matrixBlock = (MatrixBlock ) tmp .getValue ();
99102 long rowIndex = tmp .getIndexes ().getRowIndex ();
@@ -109,19 +112,29 @@ public void processInstruction( ExecutionContext ec ) {
109112 qOut .enqueueTask (new IndexedMatrixValue (tmp .getIndexes (), partialResult ));
110113 }
111114 else {
115+ // aggregate partial results per row index (in-place sum into the first partial result seen)
112116 MatrixBlock currAgg = partialResults .get (rowIndex );
113- if (currAgg == null )
117+ if (currAgg == null ) {
114118 partialResults .put (rowIndex , partialResult );
115- else
119+ cnt .put (rowIndex , 1 );
120+ }
121+ else {
116122 currAgg .binaryOperationsInPlace (plus , partialResult );
117- }
118- }
119-
120- // emit aggregated blocks
121- if ( min .getNumColumns () > min .getBlocksize () ) {
122- for (Map .Entry <Long , MatrixBlock > entry : partialResults .entrySet ()) {
123- MatrixIndexes outIndexes = new MatrixIndexes (entry .getKey (), 1L );
124- qOut .enqueueTask (new IndexedMatrixValue (outIndexes , entry .getValue ()));
123+ int newCnt = cnt .get (rowIndex ) + 1 ;
124+
125+ if (newCnt == nBlocks ){
126+ // early block output: emit aggregated block
127+ MatrixIndexes idx = new MatrixIndexes (rowIndex , 1L );
128+ MatrixBlock result = partialResults .get (rowIndex );
129+ qOut .enqueueTask (new IndexedMatrixValue (idx , result ));
130+ partialResults .remove (rowIndex );
131+ cnt .remove (rowIndex );
132+ }
133+ else {
134+ // maintain aggregation counts if not output-ready yet
135+ cnt .replace (rowIndex , newCnt );
136+ }
137+ }
125138 }
126139 }
127140 }
0 commit comments