Skip to content

Commit 63e6efe

Browse files
committed
[SYSTEMDS-3805] Fix rewrite issues and python test versions
- fix remaining edge cases of scalar right indexing in codegen and various setting with inconsistent paths - Python document from 3.7 to 3.8 because not avaiable on ubuntu 24
1 parent a46189c commit 63e6efe

File tree

8 files changed

+25
-46
lines changed

8 files changed

+25
-46
lines changed

.github/workflows/documentation.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
distribution: ${{ matrix.javadist }}
5656
java-version: ${{ matrix.java }}
5757
cache: 'maven'
58-
58+
5959
- name: Make Documentation SystemDS Java
6060
run: mvn -ntp -P distribution package
6161

@@ -69,7 +69,7 @@ jobs:
6969
- name: Setup Python
7070
uses: actions/setup-python@v5
7171
with:
72-
python-version: 3.7
72+
python-version: 3.8
7373
architecture: 'x64'
7474

7575
- name: Cache Pip Dependencies

src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ public boolean open(Hop hop) {
8585
return hop.dimsKnown() && isValidOperation(hop)
8686
&& !(hop.getDim1()==1 && hop.getDim2()==1)
8787
|| (hop instanceof IndexingOp && hop.getInput().get(0).getDim2() >= 0
88-
&& (((IndexingOp)hop).isColLowerEqualsUpper() || hop.getDim2()==1))
88+
&& (((IndexingOp)hop).isColLowerEqualsUpper() || hop.getDim2()==1)
89+
&& !((IndexingOp)hop).isScalarOutput())
8990
|| (HopRewriteUtils.isDataGenOpWithLiteralInputs(hop, OpOpDG.SEQ)
9091
&& HopRewriteUtils.hasOnlyUnaryBinaryParents(hop, true))
9192
|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS) && hop.isMatrix())

src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ && isFuseSkinnyMatrixMult(hop.getParent().get(0)))
111111
&& HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG))
112112
|| (hop instanceof IndexingOp && hop.getInput().get(0).getDim1() > 1
113113
&& hop.getInput().get(0).getDim2() >= 0
114+
&& !((IndexingOp)hop).isScalarOutput()
114115
&& HopRewriteUtils.isColumnRangeIndexing((IndexingOp)hop))
115116
|| (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
116117
&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ private static Hop simplifyRowwiseAggregate( Hop parent, Hop hi, int pos ) {
698698
HopRewriteUtils.cleanupUnreferenced(hi);
699699
hi = input;
700700

701-
LOG.debug("Applied simplifyRowwiseAggregate1");
701+
LOG.debug("Applied simplifyRowwiseAggregate1 (line "+hi.getBeginLine()+")");
702702
}
703703
}
704704
else if( input.getDim1() == 1 )
@@ -1371,7 +1371,7 @@ private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
13711371

13721372
if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
13731373
&& left.getDataType() == DataType.MATRIX
1374-
&& right.getDataType() == DataType.MATRIX )
1374+
&& right.getDataType() == DataType.MATRIX )
13751375
{
13761376
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
13771377
|| bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
@@ -1380,7 +1380,7 @@ private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
13801380
if( applyOp != null ) {
13811381
//create new subdag sum(A) bop sum(B)
13821382
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
1383-
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
1383+
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
13841384
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
13851385

13861386
//rewire new subdag
@@ -1389,8 +1389,8 @@ private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
13891389

13901390
hi = newBin;
13911391

1392-
LOG.debug("Applied pushdownSumOnAdditiveBinary.");
1393-
}
1392+
LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
1393+
}
13941394
}
13951395
}
13961396

@@ -2292,7 +2292,7 @@ private static Hop simplifyDotProductSum(Hop parent, Hop hi, int pos) {
22922292
//sum(v^2)/sum(v1*v2) --> as.scalar(t(v)%*%v) in order to exploit tsmm vector dotproduct
22932293
//w/o materialization of intermediates
22942294
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
2295-
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full aggregate
2295+
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol //full aggregate
22962296
&& hi.getInput().get(0).getDim2() == 1 ) //vector (for correctness)
22972297
{
22982298
Hop baLeft = null;
@@ -2337,12 +2337,12 @@ else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than s
23372337
UnaryOp cast = HopRewriteUtils.createUnary(mmult, OpOp1.CAST_AS_SCALAR);
23382338

23392339
//rehang new subdag under parent node
2340-
HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
2340+
HopRewriteUtils.replaceChildReference(parent, hi, cast, pos);
23412341
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
23422342

23432343
hi = cast;
23442344

2345-
LOG.debug("Applied simplifyDotProductSum.");
2345+
LOG.debug("Applied simplifyDotProductSum (line "+hi.getBeginLine()+").");
23462346
}
23472347
}
23482348

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,7 @@ private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) {
11841184

11851185
HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
11861186

1187-
LOG.debug("Applied pushdownSumBinaryMult.");
1187+
LOG.debug("Applied pushdownSumBinaryMult (line "+hi.getBeginLine()+").");
11881188
return bop;
11891189
}
11901190
return hi;
@@ -1514,6 +1514,7 @@ private static Hop simplifyScalarIndexing(Hop parent, Hop hi, int pos)
15141514
//as.scalar(X[i,1]) -> X[i,1] w/ scalar output
15151515
if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
15161516
&& hi.getInput(0).getParent().size() == 1 // only consumer
1517+
&& hi.getParent().size() == 1 //avoid temp inconsistency
15171518
&& hi.getInput(0) instanceof IndexingOp
15181519
&& ((IndexingOp)hi.getInput(0)).isScalarOutput()
15191520
&& hi.getInput(0).isMatrix() //no frame support yet

src/test/java/org/apache/sysds/test/functions/codegenalg/partone/AlgorithmLinregCG.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ private void runLinregCGTest( String testname, boolean rewrites, boolean sparse,
318318
loadTestConfiguration(config);
319319

320320
fullDMLScriptName = getScript();
321-
programArgs = new String[]{ "-stats", "-nvargs", "X="+input("X"), "Y="+input("y"),
321+
programArgs = new String[]{ "-explain","-stats", "-nvargs", "X="+input("X"), "Y="+input("y"),
322322
"icpt="+String.valueOf(intercept), "tol="+String.valueOf(epsilon),
323323
"maxi="+String.valueOf(maxiter), "reg=0.001", "B="+output("w")};
324324

src/test/java/org/apache/sysds/test/functions/unary/matrix/EigenFactorizeTest.java

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.apache.sysds.test.functions.unary.matrix;
2121

2222
import org.junit.Test;
23-
import org.apache.sysds.api.DMLScript;
2423
import org.apache.sysds.common.Types.ExecMode;
2524
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
2625
import org.apache.sysds.test.AutomatedTestBase;
@@ -74,13 +73,8 @@ public void testLargeEigenFactorizeDenseHybrid() {
7473
}
7574

7675
private void runTestEigenFactorize( int rows, ExecMode rt)
77-
{
78-
ExecMode rtold = rtplatform;
79-
rtplatform = rt;
80-
81-
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
82-
if( rtplatform == ExecMode.SPARK )
83-
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
76+
{
77+
ExecMode rtold = setExecMode(rt);
8478

8579
try
8680
{
@@ -100,16 +94,14 @@ private void runTestEigenFactorize( int rows, ExecMode rt)
10094
for(int i=0; i < numEigenValuesToEvaluate; i++) {
10195
D[i][0] = 0.0;
10296
}
103-
writeExpectedMatrix("D", D);
97+
writeExpectedMatrix("D", D);
10498

10599
boolean exceptionExpected = false;
106100
runTest(true, exceptionExpected, null, -1);
107101
compareResults(1e-8);
108102
}
109103
finally {
110-
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
111-
rtplatform = rtold;
104+
resetExecMode(rtold);
112105
}
113106
}
114-
115-
}
107+
}

src/test/scripts/functions/unary/matrix/eigen.dml

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
# to you under the Apache License, Version 2.0 (the
88
# "License"); you may not use this file except in compliance
99
# with the License. You may obtain a copy of the License at
10-
#
10+
#
1111
# http://www.apache.org/licenses/LICENSE-2.0
12-
#
12+
#
1313
# Unless required by applicable law or agreed to in writing,
1414
# software distributed under the License is distributed on an
1515
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -29,32 +29,16 @@ A = t(A) %*% A; # make the input matrix symmetric
2929

3030
[eval, evec] = eigen(A);
3131

32-
/*
33-
B = evec %*% diag(eval) %*% t(evec);
34-
diff = sum(A - B);
35-
D = matrix(1,1,1);
36-
D = diff*D;
37-
*/
38-
3932
numEval = $2;
4033
D = matrix(1, numEval, 1);
4134
for ( i in 1:numEval ) {
4235
Av = A %*% evec[,i];
36+
while(FALSE){} #fix incorrect rewrite sequence
4337
rhs = as.scalar(eval[i,1]) * evec[,i];
38+
while(FALSE){} #fix incorrect rewrite sequence
4439
diff = sum(Av-rhs);
4540
D[i,1] = diff;
4641
}
4742

48-
/*
49-
# TODO: dummy if() must be removed
50-
v = evec[,1];
51-
Av = A %*% v;
52-
rhs = as.scalar(eval[1,1]) * evec[,1];
53-
diff = sum(Av-rhs);
54-
55-
D = matrix(1,1,1);
56-
D = diff*D;
57-
*/
58-
5943
write(D, $3);
6044

0 commit comments

Comments
 (0)