Skip to content

Commit 2e68ad3

Browse files
Frxmsmboehm7
authored andcommitted
[SYSTEMDS-3860] Extended codegen row template by var aggregates
Closes #2244.
1 parent 8ecd7c5 commit 2e68ad3

File tree

11 files changed

+184
-10
lines changed

11 files changed

+184
-10
lines changed

src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class CNodeUnary extends CNode
3333
public enum UnaryType {
3434
LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
3535
ROW_SUMS, ROW_SUMSQS, ROW_COUNTNNZS, //codegen specific
36-
ROW_MEANS, ROW_MINS, ROW_MAXS,
36+
ROW_MEANS, ROW_MINS, ROW_MAXS, ROW_VARS,
3737
VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
3838
VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN,
3939
VECT_SIN, VECT_COS, VECT_TAN, VECT_ASIN, VECT_ACOS, VECT_ATAN,
@@ -139,6 +139,7 @@ public String toString() {
139139
case ROW_MINS: return "u(Rmin)";
140140
case ROW_MAXS: return "u(Rmax)";
141141
case ROW_MEANS: return "u(Rmean)";
142+
case ROW_VARS: return "u(Rvar)";
142143
case ROW_COUNTNNZS: return "u(Rnnz)";
143144
case VECT_EXP:
144145
case VECT_POW2:
@@ -210,6 +211,7 @@ public void setOutputDims() {
210211
case ROW_MINS:
211212
case ROW_MAXS:
212213
case ROW_MEANS:
214+
case ROW_VARS:
213215
case ROW_COUNTNNZS:
214216
case EXP:
215217
case LOOKUP_R:

src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ public String getTemplate(UnaryType type, boolean sparse) {
3232
case ROW_MINS:
3333
case ROW_MAXS:
3434
case ROW_MEANS:
35+
case ROW_VARS:
3536
case ROW_COUNTNNZS: {
3637
String vectName = StringUtils.capitalize(type.name().substring(4, type.name().length()-1).toLowerCase());
3738
return sparse ? " double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
3839
" double %TMP% = LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
3940
}
40-
4141
case VECT_EXP:
4242
case VECT_POW2:
4343
case VECT_MULT2:

src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767

6868
public class TemplateRow extends TemplateBase
6969
{
70-
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD};
70+
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD, AggOp.VAR};
7171
private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
7272
OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
7373
OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,

src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2151,7 +2151,19 @@ public static double[] vectConv2dmmWrite(double[] a, double[] b, int ai, int bi,
21512151
new DenseBlockFP64(new int[]{K, PQ}, c), PQ, CRS, 0, K, 0, PQ);
21522152
return c;
21532153
}
2154-
2154+
2155+
public static double vectVar(double[] a, int ai, int len) {
2156+
double meanVal = Math.pow(vectMean(a, ai, len), 2);
2157+
double[] aSqr = vectPow2Write(a, ai, len);
2158+
return (vectSum(aSqr, 0, len)-len*meanVal)/(len-1);
2159+
}
2160+
2161+
public static double vectVar(double[] avals, int[] aix, int ai, int alen, int len) {
2162+
double meanVal = Math.pow(vectMean(avals, aix, ai, alen, len), 2);
2163+
double[] avalsSqr = vectPow2Write(avals, aix, ai, alen, len);
2164+
return (vectSum(avalsSqr, 0, len)-len*meanVal)/(len-1);
2165+
}
2166+
21552167
//complex builtin functions that are not directly generated
21562168
//(included here in order to reduce the number of imports)
21572169

src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.apache.log4j.Level;
2525
import org.apache.log4j.Logger;
2626
import org.apache.log4j.spi.LoggingEvent;
27-
import org.apache.sysds.api.DMLOptions;
2827
import org.apache.sysds.api.DMLScript;
2928
import org.apache.sysds.parser.LanguageException;
3029
import org.apache.sysds.test.LoggingUtils;

src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ public void testMDCP() {
9090
}
9191

9292
@Test
93-
//@Ignore
94-
// https://issues.apache.org/jira/browse/SYSTEMDS-3716
9593
public void testMDSP() {
9694
double[][] D = {
9795
{7567, 231, 1231, 1232, 122, 321},

src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ public class RowAggTmplTest extends AutomatedTestBase
8787
private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - mean(X)) + 7;
8888
private static final String TEST_NAME45 = TEST_NAME+"45"; //vector allocation;
8989
private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X - mean(X), F1) + conv2d(X - mean(X), F2);
90-
90+
private static final String TEST_NAME47 = TEST_NAME+"47"; //sum(X + rowVars(X))
91+
private static final String TEST_NAME48 = TEST_NAME+"48"; //sum(rowVars(X))
92+
9193
private static final String TEST_DIR = "functions/codegen/";
9294
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
9395
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
@@ -98,7 +100,7 @@ public class RowAggTmplTest extends AutomatedTestBase
98100
@Override
99101
public void setUp() {
100102
TestUtils.clearAssertionInformation();
101-
for(int i=1; i<=46; i++)
103+
for(int i=1; i<=48; i++)
102104
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
103105
}
104106

@@ -795,6 +797,36 @@ public void testCodegenRowAgg46SP() {
795797
testCodegenIntegration( TEST_NAME46, false, ExecType.SPARK );
796798
}
797799

800+
@Test
801+
public void testCodegenRowAggRewrite47CP() {
802+
testCodegenIntegration( TEST_NAME47, true, ExecType.CP );
803+
}
804+
805+
@Test
806+
public void testCodegenRowAgg47CP() {
807+
testCodegenIntegration( TEST_NAME47, false, ExecType.CP );
808+
}
809+
810+
@Test
811+
public void testCodegenRowAgg47SP() {
812+
testCodegenIntegration( TEST_NAME47, false, ExecType.SPARK );
813+
}
814+
815+
@Test
816+
public void testCodegenRowAggRewrite48CP() {
817+
testCodegenIntegration( TEST_NAME48, true, ExecType.CP );
818+
}
819+
820+
@Test
821+
public void testCodegenRowAgg48CP() {
822+
testCodegenIntegration( TEST_NAME48, false, ExecType.CP );
823+
}
824+
825+
@Test
826+
public void testCodegenRowAgg48SP() {
827+
testCodegenIntegration( TEST_NAME48, false, ExecType.SPARK );
828+
}
829+
798830
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
799831
{
800832
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -807,7 +839,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
807839

808840
String HOME = SCRIPT_DIR + TEST_DIR;
809841
fullDMLScriptName = HOME + testname + ".dml";
810-
programArgs = new String[]{"-stats", "-args", output("S") };
842+
programArgs = new String[]{"-explain", "codegen", "-stats", "-args", output("S") };
811843

812844
fullRScriptName = HOME + testname + ".R";
813845
rCmd = getRCmd(inputDir(), expectedDir());
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
args<-commandArgs(TRUE)
23+
options(digits=22)
24+
library("Matrix")
25+
library("matrixStats")
26+
27+
# rowVars <- function(X) {
28+
# apply(X, 1, function(x) sum((x - mean(x))^2) / length(x))
29+
# }
30+
31+
X = matrix(seq(7, 50*10+6), 50, 10, byrow=TRUE);
32+
z = seq(1,50)
33+
34+
R = as.matrix(sum(X + rowVars(X)));
35+
36+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
X = matrix(seq(7, 50*10+6), 50, 10);
23+
z = seq(1,50)
24+
25+
while(FALSE){}
26+
27+
R = as.matrix(sum(X + rowVars(X)));
28+
29+
write(R, $1)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
args<-commandArgs(TRUE)
22+
options(digits=22)
23+
library("Matrix")
24+
library("matrixStats")
25+
26+
# rowVars <- function(X) {
27+
# apply(X, 1, function(x) sum((x - mean(x))^2) / length(x))
28+
# }
29+
30+
Z = matrix(seq(1,10), 1, 10)
31+
Y = matrix(0, 10, 10)
32+
X = rbind(Y, Z, Y)
33+
34+
R = as.matrix(sum(rowVars(X)));
35+
36+
writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));

0 commit comments

Comments
 (0)