2222import java .util .HashMap ;
2323
2424import org .apache .sysds .common .Opcodes ;
25- import org .apache .sysds .conf .ConfigurationManager ;
2625import org .apache .sysds .runtime .DMLRuntimeException ;
2726import org .apache .sysds .runtime .controlprogram .caching .MatrixObject ;
2827import org .apache .sysds .runtime .controlprogram .context .ExecutionContext ;
@@ -51,7 +50,7 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
5150 InstructionUtils .checkNumFields (parts , 4 );
5251 String opcode = parts [0 ];
5352 CPOperand in1 = new CPOperand (parts [1 ]); // the larget matrix (streamed)
54- CPOperand in2 = new CPOperand (parts [2 ]); // the small vector (in-memory )
53+ CPOperand in2 = new CPOperand (parts [2 ]); // vector operand (may be OOC )
5554 CPOperand out = new CPOperand (parts [3 ]);
5655
5756 AggregateOperator agg = new AggregateOperator (0 , Plus .getPlusFnObject ());
@@ -62,48 +61,54 @@ public static MatrixVectorBinaryOOCInstruction parseInstruction(String str) {
6261
6362 @ Override
6463 public void processInstruction ( ExecutionContext ec ) {
65- // 1. Identify the inputs
66- MatrixObject min = ec .getMatrixObject (input1 ); // big matrix
67- MatrixBlock vin = ec .getMatrixObject (input2 )
68- .acquireReadAndRelease (); // in-memory vector
69-
70- // 2. Pre-partition the in-memory vector into a hashmap
71- HashMap <Long , MatrixBlock > partitionedVector = new HashMap <>();
72- int blksize = vin .getDataCharacteristics ().getBlocksize ();
73- if (blksize < 0 )
74- blksize = ConfigurationManager .getBlocksize ();
75- for (int i =0 ; i <vin .getNumRows (); i +=blksize ) {
76- long key = (long ) (i /blksize ) + 1 ; // the key starts at 1
77- int end_row = Math .min (i + blksize , vin .getNumRows ());
78- MatrixBlock vectorSlice = vin .slice (i , end_row - 1 );
79- partitionedVector .put (key , vectorSlice );
80- }
64+ // Fetch both inputs without assuming which one fits in memory
65+ MatrixObject min = ec .getMatrixObject (input1 );
66+ MatrixObject vin = ec .getMatrixObject (input2 );
8167
8268 // number of colBlocks for early block output
8369 long emitThreshold = min .getDataCharacteristics ().getNumColBlocks ();
8470 OOCMatrixBlockTracker aggTracker = new OOCMatrixBlockTracker (emitThreshold );
8571
86- OOCStream <IndexedMatrixValue > qIn = min .getStreamHandle ();
72+ OOCStream <IndexedMatrixValue > qIn1 = min .getStreamHandle ();
73+ OOCStream <IndexedMatrixValue > qIn2 = vin .getStreamHandle (); // Stream handles for matrix and vector (both may be OOC)
8774 OOCStream <IndexedMatrixValue > qOut = createWritableStream ();
8875 BinaryOperator plus = InstructionUtils .parseBinaryOperator (Opcodes .PLUS .toString ());
8976 ec .getMatrixObject (output ).setStreamHandle (qOut );
9077
9178 submitOOCTask (() -> {
92- IndexedMatrixValue tmp = null ;
79+
9380 try {
94- while ((tmp = qIn .dequeue ()) != LocalTaskQueue .NO_MORE_TASKS ) {
81+ // Cache vector blocks indexed by their block row id
82+ // This removes the assumption that the vector is fully in-memory
83+ HashMap <Long , MatrixBlock > vectorCache = new HashMap <>();
84+
85+ // Consume the entire vector stream and cache it block-wise
86+ IndexedMatrixValue vecVal ;
87+ while ((vecVal = qIn2 .dequeue ()) != LocalTaskQueue .NO_MORE_TASKS ) {
88+ vectorCache .put (
89+ vecVal .getIndexes ().getRowIndex (),
90+ (MatrixBlock ) vecVal .getValue ());
91+ }
92+
93+ // Stream through matrix blocks and match them with vector blocks
94+ IndexedMatrixValue tmp = null ;
95+ while ((tmp = qIn1 .dequeue ()) != LocalTaskQueue .NO_MORE_TASKS ) {
9596 MatrixBlock matrixBlock = (MatrixBlock ) tmp .getValue ();
9697 long rowIndex = tmp .getIndexes ().getRowIndex ();
9798 long colIndex = tmp .getIndexes ().getColumnIndex ();
98- MatrixBlock vectorSlice = partitionedVector .get (colIndex );
99+ MatrixBlock vectorSlice = vectorCache .get (colIndex );
100+
101+ // Fail fast if the corresponding vector block is missing
102+ if (vectorSlice == null )
103+ throw new DMLRuntimeException ("Missing vector block for column block " + colIndex );
99104
100105 // Now, call the operation with the correct, specific operator.
101106 MatrixBlock partialResult = matrixBlock .aggregateBinaryOperations (
102107 matrixBlock , vectorSlice , new MatrixBlock (), (AggregateBinaryOperator ) _optr );
103108
104109 // for single column block, no aggregation neeeded
105110 if (emitThreshold == 1 ) {
106- qOut .enqueue (new IndexedMatrixValue (tmp . getIndexes ( ), partialResult ));
111+ qOut .enqueue (new IndexedMatrixValue (new MatrixIndexes ( rowIndex , 1 ), partialResult ));
107112 }
108113 else {
109114 // aggregation
@@ -129,6 +134,6 @@ public void processInstruction( ExecutionContext ec ) {
129134 finally {
130135 qOut .closeInput ();
131136 }
132- }, qIn , qOut );
137+ }, qIn1 , qIn2 , qOut );
133138 }
134139}
0 commit comments