Skip to content

Commit a46189c

Browse files
committed
[SYSTEMDS-3805] Rewrite and runtime for scalar right indexing
This patch adds a new rewrite, as well as modifies existing rewrites and runtime instructions in order to perform scalar right indexing for operations like as.scalar(X[i,1]) which avoids unnecessary createvar and cast instructions. On a scenario of running the baseline (non-vectorized) exponential smoothing on 10M data points, the patch improved end-to-end performance from from 22.3s to 12.2s (6.7s without statistics time measurements). alpha = 0.05 r = as.scalar(X[1, 1]) for(i in 2:nrow(X)) { r = alpha * as.scalar(X[i, 1]) + (1-alpha) * r } Total elapsed time: 22.348 sec. Total compilation time: 0.516 sec. Total execution time: 21.832 sec. Cache hits (Mem/Li/WB/FS/HDFS): 20000000/0/0/0/0. Cache writes (Li/WB/FS/HDFS): 1/0/0/0. Cache times (ACQr/m, RLS, EXP): 0.777/0.432/1.124/0.000 sec. HOP DAGs recompiled (PRED, SB): 0/0. HOP DAGs recompile time: 0.300 sec. Functions recompiled: 1. Functions recompile time: 0.002 sec. Total JIT compile time: 2.608 sec. Total JVM GC count: 1. Total JVM GC time: 0.018 sec. Heavy hitter instructions: 1 rightIndex 4.894 10000000 2 createvar 3.585 10000001 3 rmvar 2.848 30000000 4 castdts 2.242 10000000 5 * 1.742 19999998 6 + 0.898 9999999 7 mvvar 0.751 10000002 8 rand 0.213 1 9 - 0.016 1 10 print 0.000 1 11 assignvar 0.000 2 Total elapsed time: 12.589 sec. Total compilation time: 0.520 sec. Total execution time: 12.069 sec. Cache hits (Mem/Li/WB/FS/HDFS): 10000000/0/0/0/0. Cache writes (Li/WB/FS/HDFS): 1/0/0/0. Cache times (ACQr/m, RLS, EXP): 0.455/0.000/0.463/0.000 sec. HOP DAGs recompiled (PRED, SB): 0/0. HOP DAGs recompile time: 0.313 sec. Functions recompiled: 1. Functions recompile time: 0.002 sec. Total JIT compile time: 1.923 sec. Total JVM GC count: 1. Total JVM GC time: 0.011 sec. Heavy hitter instructions: 1 rightIndex 3.046 10000000 2 * 1.876 19999998 3 rmvar 1.450 20000000 4 + 0.954 9999999 5 mvvar 0.801 10000002 6 rand 0.213 1 7 - 0.018 1 8 print 0.000 1 9 createvar 0.000 1 10 assignvar 0.000 2
1 parent 3b5e0bc commit a46189c

File tree

11 files changed

+220
-42
lines changed

11 files changed

+220
-42
lines changed

src/main/java/org/apache/sysds/hops/IndexingOp.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ public IndexingOp(String l, DataType dt, ValueType vt, Hop inpMatrix, Hop inpRow
7373
setRowLowerEqualsUpper(passedRowsLEU);
7474
setColLowerEqualsUpper(passedColsLEU);
7575
}
76+
77+
public boolean isScalarOutput() {
78+
return isRowLowerEqualsUpper() && isColLowerEqualsUpper();
79+
}
7680

7781
public boolean isRowLowerEqualsUpper(){
7882
return _rowLowerEqualsUpper;

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ public static boolean isConsecutiveIndex(Hop index, Hop index2) {
13321332
}
13331333

13341334
public static boolean isUnnecessaryRightIndexing(Hop hop) {
1335-
if( !(hop instanceof IndexingOp) )
1335+
if( !(hop instanceof IndexingOp) || hop.isScalar() )
13361336
return false;
13371337
//note: in addition to equal sizes, we also check a valid
13381338
//starting row and column ranges of 1 in order to guard against

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ private static Hop removeEmptyRightIndexing(Hop parent, Hop hi, int pos)
241241

242242
private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos)
243243
{
244-
if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) ) {
244+
if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && !hi.isScalar() ) {
245245
//remove unnecessary right indexing
246246
Hop input = hi.getInput().get(0);
247247
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
174174
hi = simplifyTraceMatrixMult(hop, hi, i); //e.g., trace(X%*%Y)->sum(X*t(Y));
175175
hi = simplifySlicedMatrixMult(hop, hi, i); //e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
176176
hi = simplifyListIndexing(hi); //e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
177+
hi = simplifyScalarIndexing(hop, hi, i); //e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
177178
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
178179
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
179180
hi = fuseOrderOperationChain(hi); //e.g., order(order(X,2),1) -> order(X,(12))
@@ -1508,6 +1509,27 @@ private static Hop simplifyListIndexing(Hop hi) {
15081509
return hi;
15091510
}
15101511

1512+
private static Hop simplifyScalarIndexing(Hop parent, Hop hi, int pos)
1513+
{
1514+
//as.scalar(X[i,1]) -> X[i,1] w/ scalar output
1515+
if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
1516+
&& hi.getInput(0).getParent().size() == 1 // only consumer
1517+
&& hi.getInput(0) instanceof IndexingOp
1518+
&& ((IndexingOp)hi.getInput(0)).isScalarOutput()
1519+
&& hi.getInput(0).isMatrix() //no frame support yet
1520+
&& !HopRewriteUtils.isData(parent, OpOpData.TRANSIENTWRITE))
1521+
{
1522+
Hop hi2 = hi.getInput().get(0);
1523+
hi2.setDataType(DataType.SCALAR);
1524+
hi2.setDim1(0); hi2.setDim2(0);
1525+
HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos);
1526+
HopRewriteUtils.cleanupUnreferenced(hi);
1527+
hi = hi2;
1528+
LOG.debug("Applied simplifyScalarIndexing (line "+hi.getBeginLine()+").");
1529+
}
1530+
return hi;
1531+
}
1532+
15111533
private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
15121534
{
15131535
//order(matrix(7), indexreturn=FALSE) -> matrix(7)

src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ private static void vectorizeRightIndexing( Hop hop )
186186
ihops.add(ihop0);
187187
for( Hop c : input.getParent() ){
188188
if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
189-
&& ((IndexingOp) c).isRowLowerEqualsUpper()
190-
&& c.getInput().get(1)==ihop0.getInput().get(1) )
189+
&& ((IndexingOp) c).isRowLowerEqualsUpper() && !c.isScalar()
190+
&& c.getInput().get(1)==ihop0.getInput().get(1) )
191191
{
192192
ihops.add( c );
193193
}
@@ -225,7 +225,7 @@ private static void vectorizeRightIndexing( Hop hop )
225225
ihops.add(ihop0);
226226
for( Hop c : input.getParent() ){
227227
if( c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
228-
&& ((IndexingOp) c).isColLowerEqualsUpper()
228+
&& ((IndexingOp) c).isColLowerEqualsUpper() && !c.isScalar()
229229
&& c.getInput().get(3)==ihop0.getInput().get(3) )
230230
{
231231
ihops.add( c );

src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,39 +52,46 @@ public void processInstruction(ExecutionContext ec) {
5252
String opcode = getOpcode();
5353
IndexRange ix = getIndexRange(ec);
5454

55-
//get original matrix
5655
MatrixObject mo = ec.getMatrixObject(input1.getName());
56+
boolean inRange = ix.rowStart < mo.getNumRows() && ix.colStart < mo.getNumColumns();
5757

5858
//right indexing
5959
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
6060
{
61-
MatrixBlock resultBlock = null;
62-
63-
if( mo.isPartitioned() ) //via data partitioning
64-
resultBlock = mo.readMatrixPartition(ix.add(1));
65-
else if( ix.isScalar() && ix.rowStart < mo.getNumRows() && ix.colStart < mo.getNumColumns() ) {
61+
if( output.isScalar() && inRange ) { //SCALAR out
6662
MatrixBlock matBlock = mo.acquireReadAndRelease();
67-
resultBlock = new MatrixBlock(
68-
matBlock.get((int)ix.rowStart, (int)ix.colStart));
63+
ec.setScalarOutput(output.getName(),
64+
new DoubleObject(matBlock.get((int)ix.rowStart, (int)ix.colStart)));
6965
}
70-
else //via slicing the in-memory matrix
71-
{
72-
//execute right indexing operation (with shallow row copies for range
73-
//of entire sparse rows, which is safe due to copy on update)
74-
MatrixBlock matBlock = mo.acquireRead();
75-
resultBlock = matBlock.slice((int)ix.rowStart, (int)ix.rowEnd,
76-
(int)ix.colStart, (int)ix.colEnd, false, new MatrixBlock());
66+
else { //MATRIX out
67+
MatrixBlock resultBlock = null;
7768

78-
//unpin rhs input
79-
ec.releaseMatrixInput(input1.getName());
69+
if( mo.isPartitioned() ) //via data partitioning
70+
resultBlock = mo.readMatrixPartition(ix.add(1));
71+
else if( ix.isScalar() && inRange ) {
72+
MatrixBlock matBlock = mo.acquireReadAndRelease();
73+
resultBlock = new MatrixBlock(
74+
matBlock.get((int)ix.rowStart, (int)ix.colStart));
75+
}
76+
else //via slicing the in-memory matrix
77+
{
78+
//execute right indexing operation (with shallow row copies for range
79+
//of entire sparse rows, which is safe due to copy on update)
80+
MatrixBlock matBlock = mo.acquireRead();
81+
resultBlock = matBlock.slice((int)ix.rowStart, (int)ix.rowEnd,
82+
(int)ix.colStart, (int)ix.colEnd, false, new MatrixBlock());
83+
84+
//unpin rhs input
85+
ec.releaseMatrixInput(input1.getName());
86+
87+
//ensure correct sparse/dense output representation
88+
if( checkGuardedRepresentationChange(matBlock, resultBlock) )
89+
resultBlock.examSparsity();
90+
}
8091

81-
//ensure correct sparse/dense output representation
82-
if( checkGuardedRepresentationChange(matBlock, resultBlock) )
83-
resultBlock.examSparsity();
92+
//unpin output
93+
ec.setMatrixOutput(output.getName(), resultBlock);
8494
}
85-
86-
//unpin output
87-
ec.setMatrixOutput(output.getName(), resultBlock);
8895
}
8996
//left indexing
9097
else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE))

src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,11 @@ private void processCastAsScalarVariableInstruction(ExecutionContext ec){
897897
ec.setVariable(output.getName(), list.slice(0));
898898
break;
899899
}
900+
case SCALAR: {
901+
//for robustness in case rewrites added unnecessary as.scalars
902+
ec.setScalarOutput(output.getName(), ec.getScalarInput(getInput1()));
903+
break;
904+
}
900905
default:
901906
throw new DMLRuntimeException("Unsupported data type "
902907
+ "in as.scalar(): "+getInput1().getDataType().name());

src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
3636
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
3737
import org.apache.sysds.runtime.instructions.cp.CPOperand;
38+
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
3839
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
3940
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
4041
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -47,6 +48,7 @@
4748
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
4849
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
4950
import org.apache.sysds.runtime.meta.DataCharacteristics;
51+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
5052
import org.apache.sysds.runtime.util.IndexRange;
5153
import org.apache.sysds.runtime.util.UtilFunctions;
5254
import scala.Function1;
@@ -103,26 +105,35 @@ public void processInstruction(ExecutionContext ec) {
103105
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
104106
{
105107
//update and check output dimensions
106-
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
108+
DataCharacteristics mcOut = output.isScalar() ?
109+
new MatrixCharacteristics(1,1) :
110+
ec.getDataCharacteristics(output.getName());
107111
mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(), mcIn.getBlocksize());
108112
mcOut.setNonZerosBound(Math.min(mcOut.getLength(), mcIn.getNonZerosBound()));
109113
checkValidOutputDimensions(mcOut);
110114

111115
//execute right indexing operation (partitioning-preserving if possible)
112116
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
113-
114-
if( isSingleBlockLookup(mcIn, ixrange) ) {
115-
sec.setMatrixOutput(output.getName(), singleBlockIndexing(in1, mcIn, mcOut, ixrange));
116-
}
117-
else if( isMultiBlockLookup(in1, mcIn, mcOut, ixrange) ) {
118-
sec.setMatrixOutput(output.getName(), multiBlockIndexing(in1, mcIn, mcOut, ixrange));
117+
118+
if( output.isScalar() ) { //SCALAR output
119+
MatrixBlock ret = singleBlockIndexing(in1, mcIn, mcOut, ixrange);
120+
sec.setScalarOutput(output.getName(), new DoubleObject(ret.get(0, 0)));
119121
}
120-
else { //rdd output for general case
121-
JavaPairRDD<MatrixIndexes,MatrixBlock> out = generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
122+
else { //MATRIX output
122123

123-
//put output RDD handle into symbol table
124-
sec.setRDDHandleForVariable(output.getName(), out);
125-
sec.addLineageRDD(output.getName(), input1.getName());
124+
if( isSingleBlockLookup(mcIn, ixrange) ) {
125+
sec.setMatrixOutput(output.getName(), singleBlockIndexing(in1, mcIn, mcOut, ixrange));
126+
}
127+
else if( isMultiBlockLookup(in1, mcIn, mcOut, ixrange) ) {
128+
sec.setMatrixOutput(output.getName(), multiBlockIndexing(in1, mcIn, mcOut, ixrange));
129+
}
130+
else { //rdd output for general case
131+
JavaPairRDD<MatrixIndexes,MatrixBlock> out = generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
132+
133+
//put output RDD handle into symbol table
134+
sec.setRDDHandleForVariable(output.getName(), out);
135+
sec.addLineageRDD(output.getName(), input1.getName());
136+
}
126137
}
127138
}
128139
//left indexing
@@ -178,12 +189,13 @@ else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE) || opcode.equalsIgnoreCase("
178189
sec.addLineageRDD(output.getName(), input2.getName());
179190
}
180191
else
181-
throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingSPInstruction.");
192+
throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingSPInstruction.");
182193
}
183194

184195

185196
public static MatrixBlock inmemoryIndexing(JavaPairRDD<MatrixIndexes,MatrixBlock> in1,
186-
DataCharacteristics mcIn, DataCharacteristics mcOut, IndexRange ixrange) {
197+
DataCharacteristics mcIn, DataCharacteristics mcOut, IndexRange ixrange)
198+
{
187199
if( isSingleBlockLookup(mcIn, ixrange) ) {
188200
return singleBlockIndexing(in1, mcIn, mcOut, ixrange);
189201
}

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.HashMap;
2323

2424
import org.junit.Assert;
25+
import org.junit.Ignore;
2526
import org.junit.Test;
2627
import org.apache.sysds.hops.OptimizerUtils;
2728
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -57,6 +58,7 @@ public void testLoopVectorizationSumNoRewrite() {
5758
}
5859

5960
@Test
61+
@Ignore //FIXME: extend loop vectorization rewrite
6062
public void testLoopVectorizationSumRewrite() {
6163
testRewriteLoopVectorizationSum( TEST_NAME1, true );
6264
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
23+
import org.junit.Assert;
24+
import org.junit.Test;
25+
26+
import org.apache.sysds.common.Types.ExecMode;
27+
import org.apache.sysds.common.Types.ExecType;
28+
import org.apache.sysds.hops.OptimizerUtils;
29+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
30+
import org.apache.sysds.test.AutomatedTestBase;
31+
import org.apache.sysds.test.TestConfiguration;
32+
import org.apache.sysds.utils.Statistics;
33+
34+
public class RewriteScalarRightIndexingTest extends AutomatedTestBase
35+
{
36+
private final static String TEST_DIR = "functions/rewrite/";
37+
private final static String TEST_NAME = "RewriteScalarRightIndexing";
38+
39+
private final static String TEST_CLASS_DIR = TEST_DIR + RewriteScalarRightIndexingTest.class.getSimpleName() + "/";
40+
41+
private final static int rows = 122;
42+
43+
@Override
44+
public void setUp() {
45+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"}));
46+
}
47+
48+
@Test
49+
public void testScalarRightIndexingCP() {
50+
runScalarRightIndexing(true, ExecType.CP);
51+
}
52+
53+
@Test
54+
public void testScalarRightIndexingNoRewriteCP() {
55+
runScalarRightIndexing(false, ExecType.CP);
56+
}
57+
58+
@Test
59+
public void testScalarRightIndexingSpark() {
60+
runScalarRightIndexing(true, ExecType.SPARK);
61+
}
62+
63+
@Test
64+
public void testScalarRightIndexingNoRewriteSpark() {
65+
runScalarRightIndexing(false, ExecType.SPARK);
66+
}
67+
68+
private void runScalarRightIndexing(boolean rewrite, ExecType instType) {
69+
ExecMode platformOld = setExecMode(instType);
70+
boolean flagOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
71+
try {
72+
TestConfiguration config = getTestConfiguration(TEST_NAME);
73+
loadTestConfiguration(config);
74+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
75+
76+
String HOME = SCRIPT_DIR + TEST_DIR;
77+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
78+
programArgs = new String[]{"-explain", "-stats", "-args",
79+
Long.toString(rows), output("A")};
80+
runTest(true, false, null, -1);
81+
82+
Double ret = readDMLScalarFromOutputDir("A").get(new CellIndex(1,1));
83+
Assert.assertEquals(Double.valueOf(103.0383), ret, 1e-4);
84+
if(rewrite) //w/o rewrite 122 casts
85+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("castdts")<=1);
86+
}
87+
finally {
88+
resetExecMode(platformOld);
89+
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = flagOld;
90+
}
91+
}
92+
}

0 commit comments

Comments
 (0)