Skip to content

Commit 3b5e0bc

Browse files
committed
[SYSTEMDS-3805] Fix scalar right indexing (only for valid indices)
In order to ensure consistent error handling, we now only use the scalar right indexing if the index-range is within the matrix dims.
1 parent 716d5ce commit 3b5e0bc

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ protected MatrixIndexingCPInstruction(CPOperand lhsInput, CPOperand rhsInput, CP
5050
@Override
5151
public void processInstruction(ExecutionContext ec) {
5252
String opcode = getOpcode();
53-
IndexRange ixrange = getIndexRange(ec);
53+
IndexRange ix = getIndexRange(ec);
5454

5555
//get original matrix
5656
MatrixObject mo = ec.getMatrixObject(input1.getName());
@@ -61,19 +61,19 @@ public void processInstruction(ExecutionContext ec) {
6161
MatrixBlock resultBlock = null;
6262

6363
if( mo.isPartitioned() ) //via data partitioning
64-
resultBlock = mo.readMatrixPartition(ixrange.add(1));
65-
else if( ixrange.isScalar() ){
64+
resultBlock = mo.readMatrixPartition(ix.add(1));
65+
else if( ix.isScalar() && ix.rowStart < mo.getNumRows() && ix.colStart < mo.getNumColumns() ) {
6666
MatrixBlock matBlock = mo.acquireReadAndRelease();
6767
resultBlock = new MatrixBlock(
68-
matBlock.get((int)ixrange.rowStart, (int)ixrange.colStart));
68+
matBlock.get((int)ix.rowStart, (int)ix.colStart));
6969
}
7070
else //via slicing the in-memory matrix
7171
{
7272
//execute right indexing operation (with shallow row copies for range
7373
//of entire sparse rows, which is safe due to copy on update)
7474
MatrixBlock matBlock = mo.acquireRead();
75-
resultBlock = matBlock.slice((int)ixrange.rowStart, (int)ixrange.rowEnd,
76-
(int)ixrange.colStart, (int)ixrange.colEnd, false, new MatrixBlock());
75+
resultBlock = matBlock.slice((int)ix.rowStart, (int)ix.rowEnd,
76+
(int)ix.colStart, (int)ix.colEnd, false, new MatrixBlock());
7777

7878
//unpin rhs input
7979
ec.releaseMatrixInput(input1.getName());
@@ -101,15 +101,15 @@ else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE))
101101

102102
if(input2.getDataType() == DataType.MATRIX) { //MATRIX<-MATRIX
103103
MatrixBlock rhsMatBlock = ec.getMatrixInput(input2.getName());
104-
resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ixrange, new MatrixBlock(), updateType);
104+
resultBlock = matBlock.leftIndexingOperations(rhsMatBlock, ix, new MatrixBlock(), updateType);
105105
ec.releaseMatrixInput(input2.getName());
106106
}
107107
else { //MATRIX<-SCALAR
108-
if(!ixrange.isScalar())
109-
throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ixrange.toString()+"." );
108+
if(!ix.isScalar())
109+
throw new DMLRuntimeException("Invalid index range of scalar leftindexing: "+ix.toString()+"." );
110110
ScalarObject scalar = ec.getScalarInput(input2.getName(), ValueType.FP64, input2.isLiteral());
111111
resultBlock = matBlock.leftIndexingOperations(scalar,
112-
(int)ixrange.rowStart, (int)ixrange.colStart, new MatrixBlock(), updateType);
112+
(int)ix.rowStart, (int)ix.colStart, new MatrixBlock(), updateType);
113113
}
114114

115115
//unpin lhs input

0 commit comments

Comments
 (0)