Skip to content

Commit e705f89

Browse files
committed
[MINOR] Code cleanups in rewrites and tests
1 parent 2dcd822 commit e705f89

File tree

3 files changed

+38
-60
lines changed

3 files changed

+38
-60
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi, int pos)
243243
{
244244
if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) && !hi.isScalar() ) {
245245
//remove unnecessary right indexing
246-
Hop input = hi.getInput().get(0);
246+
Hop input = hi.getInput(0);
247247
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
248248
HopRewriteUtils.cleanupUnreferenced(hi);
249249
hi = input;
@@ -258,8 +258,8 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos)
258258
{
259259
if( hi instanceof LeftIndexingOp && hi.getDataType() == DataType.MATRIX ) //left indexing op
260260
{
261-
Hop input1 = hi.getInput().get(0); //lhs matrix
262-
Hop input2 = hi.getInput().get(1); //rhs matrix
261+
Hop input1 = hi.getInput(0); //lhs matrix
262+
Hop input2 = hi.getInput(1); //rhs matrix
263263

264264
if( input1.getNnz()==0 //nnz original known and empty
265265
&& input2.getNnz()==0 ) //nnz input known and empty
@@ -271,7 +271,7 @@ private static Hop removeEmptyLeftIndexing(Hop parent, Hop hi, int pos)
271271
hi = hnew;
272272

273273
LOG.debug("Applied removeEmptyLeftIndexing");
274-
}
274+
}
275275
}
276276

277277
return hi;
@@ -281,19 +281,19 @@ private static Hop removeUnnecessaryLeftIndexing(Hop parent, Hop hi, int pos)
281281
{
282282
if( hi instanceof LeftIndexingOp ) //left indexing op
283283
{
284-
Hop input = hi.getInput().get(1); //rhs matrix/frame
284+
Hop input = hi.getInput(1); //rhs matrix/frame
285285

286286
if( HopRewriteUtils.isEqualSize(hi, input) ) //equal dims
287287
{
288288
//equal dims of left indexing input and output -> no need for indexing
289289

290-
//remove unnecessary right indexing
290+
//remove unnecessary right indexing
291291
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
292292
HopRewriteUtils.cleanupUnreferenced(hi);
293293
hi = input;
294294

295295
LOG.debug("Applied removeUnnecessaryLeftIndexing");
296-
}
296+
}
297297
}
298298

299299
return hi;
@@ -306,15 +306,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos)
306306
//pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
307307
if( hi instanceof LeftIndexingOp //first lix
308308
&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
309-
&& hi.getInput().get(0) instanceof LeftIndexingOp //second lix
309+
&& hi.getInput(0) instanceof LeftIndexingOp //second lix
310310
&& HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
311-
&& hi.getInput().get(0).getParent().size()==1 //first lix is single consumer
312-
&& hi.getInput().get(0).getInput().get(0).getDim2() == 2 ) //two column matrix
311+
&& hi.getInput(0).getParent().size()==1 //first lix is single consumer
312+
&& hi.getInput(0).getInput(0).getDim2() == 2 ) //two column matrix
313313
{
314-
Hop input2 = hi.getInput().get(1); //rhs matrix
315-
Hop pred2 = hi.getInput().get(4); //cl=cu
316-
Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix
317-
Hop pred1 = hi.getInput().get(0).getInput().get(4); //cl=cu
314+
Hop input2 = hi.getInput(1); //rhs matrix
315+
Hop pred2 = hi.getInput(4); //cl=cu
316+
Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
317+
Hop pred1 = hi.getInput(0).getInput(4); //cl=cu
318318

319319
if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
320320
&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -332,15 +332,15 @@ private static Hop fuseLeftIndexingChainToAppend(Hop parent, Hop hi, int pos)
332332
//pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B)
333333
if( !applied && hi instanceof LeftIndexingOp //first lix
334334
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
335-
&& hi.getInput().get(0) instanceof LeftIndexingOp //second lix
335+
&& hi.getInput(0) instanceof LeftIndexingOp //second lix
336336
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
337-
&& hi.getInput().get(0).getParent().size()==1 //first lix is single consumer
338-
&& hi.getInput().get(0).getInput().get(0).getDim1() == 2 ) //two column matrix
337+
&& hi.getInput(0).getParent().size()==1 //first lix is single consumer
338+
&& hi.getInput(0).getInput(0).getDim1() == 2 ) //two column matrix
339339
{
340-
Hop input2 = hi.getInput().get(1); //rhs matrix
341-
Hop pred2 = hi.getInput().get(2); //rl=ru
342-
Hop input1 = hi.getInput().get(0).getInput().get(1); //lhs matrix
343-
Hop pred1 = hi.getInput().get(0).getInput().get(2); //rl=ru
340+
Hop input2 = hi.getInput(1); //rhs matrix
341+
Hop pred2 = hi.getInput(2); //rl=ru
342+
Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
343+
Hop pred1 = hi.getInput(0).getInput(2); //rl=ru
344344

345345
if( pred1 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
346346
&& pred2 instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -364,19 +364,19 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)
364364
{
365365
if( hi instanceof UnaryOp && ((UnaryOp)hi).isCumulativeUnaryOperation() )
366366
{
367-
Hop input = hi.getInput().get(0); //input matrix
367+
Hop input = hi.getInput(0); //input matrix
368368

369369
if( HopRewriteUtils.isDimsKnown(input) //dims input known
370370
&& input.getDim1()==1 ) //1 row
371371
{
372372
OpOp1 op = ((UnaryOp)hi).getOp();
373373

374-
//remove unnecessary unary cumsum operator
374+
//remove unnecessary unary cumsum operator
375375
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
376376
hi = input;
377377

378378
LOG.debug("Applied removeUnnecessaryCumulativeOp: "+op);
379-
}
379+
}
380380
}
381381

382382
return hi;
@@ -413,27 +413,27 @@ private static Hop removeUnnecessaryOuterProduct(Hop parent, Hop hi, int pos)
413413
if( hi instanceof BinaryOp ) //binary cell operation
414414
{
415415
OpOp2 bop = ((BinaryOp)hi).getOp();
416-
Hop left = hi.getInput().get(0);
417-
Hop right = hi.getInput().get(1);
416+
Hop left = hi.getInput(0);
417+
Hop right = hi.getInput(1);
418418

419419
//check for matrix-vector column replication: (A + b %*% ones) -> (A + b)
420420
if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
421421
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
422-
&& right.getInput().get(0).getDim2() == 1 ) //column vector for mv binary
422+
&& right.getInput(0).getDim2() == 1 ) //column vector for mv binary
423423
{
424424
//remove unnecessary outer product
425-
HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(0), 1 );
425+
HopRewriteUtils.replaceChildReference(hi, right, right.getInput(0), 1 );
426426
HopRewriteUtils.cleanupUnreferenced(right);
427427

428428
LOG.debug("Applied removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
429429
}
430430
//check for matrix-vector row replication: (A + ones %*% b) -> (A + b)
431431
else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
432-
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
433-
&& right.getInput().get(1).getDim1() == 1 ) //row vector for mv binary
432+
&& HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0), 1)
433+
&& right.getInput(1).getDim1() == 1 ) //row vector for mv binary
434434
{
435435
//remove unnecessary outer product
436-
HopRewriteUtils.replaceChildReference(hi, right, right.getInput().get(1), 1 );
436+
HopRewriteUtils.replaceChildReference(hi, right, right.getInput(1), 1 );
437437
HopRewriteUtils.cleanupUnreferenced(right);
438438

439439
LOG.debug("Applied removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
@@ -442,11 +442,11 @@ else if( HopRewriteUtils.isMatrixMultiply(right) //matrix mult with datagen
442442
else if(HopRewriteUtils.isValidOuterBinaryOp(bop)
443443
&& HopRewriteUtils.isMatrixMultiply(left)
444444
&& HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
445-
&& (left.getInput().get(0).getDim2() == 1 //outer product
446-
|| left.getInput().get(1).getDim1() == 1)
445+
&& (left.getInput(0).getDim2() == 1 //outer product
446+
|| left.getInput(1).getDim1() == 1)
447447
&& left.getDim1() != 1 && right.getDim1() == 1 ) //outer vector binary
448448
{
449-
Hop hnew = HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
449+
Hop hnew = HopRewriteUtils.createBinary(left.getInput(0), right, bop, true);
450450
HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
451451
HopRewriteUtils.cleanupUnreferenced(hi);
452452

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import org.junit.Assert;
2525
import org.junit.Test;
26-
import org.apache.sysds.api.DMLScript;
2726
import org.apache.sysds.common.Types.ExecMode;
2827
import org.apache.sysds.hops.OptimizerUtils;
2928
import org.apache.sysds.common.Types.ExecType;
@@ -74,16 +73,7 @@ public void testMatrixMultChainOptRewritesSP() {
7473

7574
private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
7675
{
77-
ExecMode platformOld = rtplatform;
78-
switch( et ){
79-
case SPARK: rtplatform = ExecMode.SPARK; break;
80-
default: rtplatform = ExecMode.HYBRID; break;
81-
}
82-
83-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
84-
if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID )
85-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
86-
76+
ExecMode platformOld = setExecMode(et);
8777
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
8878
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
8979

@@ -126,8 +116,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe
126116
}
127117
finally {
128118
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
129-
rtplatform = platformOld;
130-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
119+
resetExecMode(platformOld);
131120
}
132121
}
133122
}

src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import org.junit.Assert;
2525
import org.junit.Test;
26-
import org.apache.sysds.api.DMLScript;
2726
import org.apache.sysds.common.Types.ExecMode;
2827
import org.apache.sysds.hops.OptimizerUtils;
2928
import org.apache.sysds.common.Types.ExecType;
@@ -73,16 +72,7 @@ public void testMatrixMultChainOptRewritesSP() {
7372

7473
private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
7574
{
76-
ExecMode platformOld = rtplatform;
77-
switch( et ){
78-
case SPARK: rtplatform = ExecMode.SPARK; break;
79-
default: rtplatform = ExecMode.HYBRID; break;
80-
}
81-
82-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
83-
if( rtplatform == ExecMode.SPARK || rtplatform == ExecMode.HYBRID )
84-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
85-
75+
ExecMode platformOld = setExecMode(et);
8676
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
8777
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
8878

@@ -119,8 +109,7 @@ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, Exe
119109
}
120110
finally {
121111
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
122-
rtplatform = platformOld;
123-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
112+
resetExecMode(platformOld);
124113
}
125114
}
126115
}

0 commit comments

Comments
 (0)