Skip to content

Commit 0ec4dcd

Browse files
Frxmsmboehm7
authored andcommitted
[SYSTEMDS-3844] Extended codegen templates w/ row/colProd Agg
Closes #2234.
1 parent 147519e commit 0ec4dcd

File tree

12 files changed

+573
-7
lines changed

12 files changed

+573
-7
lines changed

src/main/java/org/apache/sysds/hops/codegen/template/TemplateCell.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
public class TemplateCell extends TemplateBase
6767
{
6868
private static final AggOp[] SUPPORTED_AGG =
69-
new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX};
69+
new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.PROD};
7070

7171
public TemplateCell() {
7272
super(TemplateType.CELL);

src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767

6868
public class TemplateRow extends TemplateBase
6969
{
70-
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
70+
private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD};
7171
private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
7272
OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
7373
OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,

src/main/java/org/apache/sysds/runtime/codegen/SpoofCellwise.java

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ public enum AggOp {
8484
SUM_SQ,
8585
MIN,
8686
MAX,
87+
PROD
8788
}
8889

8990
protected final CellType _type;
@@ -332,12 +333,16 @@ private long executeDense(DenseBlock a, SideInput[] b, double[] scalars,
332333
else if( _type == CellType.ROW_AGG ) {
333334
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
334335
return executeDenseRowAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
336+
else if(_aggOp == AggOp.PROD)
337+
return executeDenseRowProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
335338
else
336339
return executeDenseRowAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
337340
}
338341
else if( _type == CellType.COL_AGG ) {
339342
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
340343
return executeDenseColAggSum(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
344+
else if(_aggOp == AggOp.PROD)
345+
return executeDenseColProd(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
341346
else
342347
return executeDenseColAggMxx(a, lb, scalars, c, m, n, sparseSafe, rl, ru, rix);
343348
}
@@ -372,12 +377,16 @@ private long executeSparse(SparseBlock sblock, SideInput[] b, double[] scalars,
372377
else if( _type == CellType.ROW_AGG ) {
373378
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
374379
return executeSparseRowAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
380+
else if( _aggOp == AggOp.PROD)
381+
return executeSparseRowProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
375382
else
376383
return executeSparseRowAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
377384
}
378385
else if( _type == CellType.COL_AGG ) {
379386
if( _aggOp == AggOp.SUM || _aggOp == AggOp.SUM_SQ )
380387
return executeSparseColAggSum(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
388+
else if( _aggOp == AggOp.PROD)
389+
return executeSparseColProd(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
381390
else
382391
return executeSparseColAggMxx(sblock, lb, scalars, out, m, n, sparseSafe, rl, ru, rix);
383392
}
@@ -930,7 +939,215 @@ private double executeSparseAggMxx(SparseBlock sblock, SideInput[] b, double[] s
930939
}
931940
return ret;
932941
}
933-
942+
943+
private long executeDenseRowProd(DenseBlock a, SideInput[] b, double[] scalars,
944+
DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
945+
{
946+
// single block output
947+
double[] lc = c.valuesAt(0);
948+
long lnnz = 0;
949+
if(a == null && !sparseSafe) {
950+
for(int i = rl; i < ru; i++) {
951+
for(int j = 0; j < n; j++) {
952+
if(j == 0) {
953+
lc[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
954+
} else if(lc[i] != 0) {
955+
lc[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
956+
} else {
957+
break;
958+
}
959+
}
960+
lnnz += (lc[i]!=0) ? 1 : 0;
961+
}
962+
}
963+
else if( a != null ) {
964+
for(int i = rl; i < ru; i++) {
965+
double[] avals = a.values(i);
966+
int aix = a.pos(i);
967+
for(int j = 0; j < n; j++) {
968+
double aval = avals[aix + j];
969+
if(aval != 0 || !sparseSafe) {
970+
if(j == 0) {
971+
lc[i] = genexec(aval, b, scalars, m, n, rix+i, i, j);
972+
} else if(lc[i] != 0) {
973+
lc[i] *= genexec(aval, b, scalars, m, n, rix+i, i, j);
974+
} else {
975+
break;
976+
}
977+
} else {
978+
break;
979+
}
980+
}
981+
lnnz += (lc[i] != 0) ? 1 : 0;
982+
}
983+
}
984+
return lnnz;
985+
}
986+
987+
private long executeDenseColProd(DenseBlock a, SideInput[] b, double[] scalars,
988+
DenseBlock c, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
989+
{
990+
double[] lc = c.valuesAt(0);
991+
//track the cols that have a zero
992+
boolean[] zeroFlag = new boolean[n];
993+
if(a == null && !sparseSafe) {
994+
for(int i = rl; i < ru; i++) {
995+
for(int j = 0; j < n; j++) {
996+
if(!zeroFlag[j]) {
997+
if(i == 0) {
998+
lc[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
999+
} else if(lc[j] != 0) {
1000+
lc[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
1001+
} else {
1002+
zeroFlag[j] = true;
1003+
}
1004+
}
1005+
}
1006+
}
1007+
}
1008+
else if(a != null) {
1009+
for(int i = rl; i < ru; i++) {
1010+
double[] avals = a.values(i);
1011+
int aix = a.pos(i);
1012+
for(int j = 0; j < n; j++) {
1013+
if(!zeroFlag[j]) {
1014+
double aval = avals[aix + j];
1015+
if(aval != 0 || !sparseSafe) {
1016+
if(i == 0) {
1017+
lc[j] = genexec(aval, b, scalars, m, n, rix + i, i, j);
1018+
} else if(lc[j] != 0) {
1019+
lc[j] *= genexec(aval, b, scalars, m, n, rix + i, i, j);
1020+
} else {
1021+
zeroFlag[j] = true;
1022+
}
1023+
}
1024+
} else {
1025+
zeroFlag[j] = true;
1026+
}
1027+
}
1028+
}
1029+
}
1030+
return -1;
1031+
}
1032+
1033+
private long executeSparseRowProd(SparseBlock sblock, SideInput[] b, double[] scalars,
1034+
MatrixBlock out, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
1035+
{
1036+
double[] c = out.getDenseBlockValues();
1037+
long lnnz = 0;
1038+
for(int i = rl; i < ru; i++) {
1039+
int lastj = -1;
1040+
if(sblock != null && !sblock.isEmpty(i)) {
1041+
int apos = sblock.pos(i);
1042+
int alen = sblock.size(i);
1043+
int[] aix = sblock.indexes(i);
1044+
double[] avals = sblock.values(i);
1045+
for(int k = apos; k < apos+alen; k++) {
1046+
if(!sparseSafe) {
1047+
for(int j=lastj+1; j<aix[k]; j++) {
1048+
if(j == 0) {
1049+
c[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
1050+
} else if(c[i] != 0){
1051+
c[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
1052+
} else {
1053+
break;
1054+
}
1055+
}
1056+
}
1057+
if(aix.length == n || !sparseSafe) {
1058+
if(aix[k] == 0) {
1059+
lastj = aix[k];
1060+
c[i] = genexec(avals[k], b, scalars, m, n, rix+i, i, k);
1061+
} else if(c[i] != 0){
1062+
lastj = aix[k];
1063+
c[i] *= genexec(avals[k], b, scalars, m, n, rix+i, i, k);
1064+
} else {
1065+
break;
1066+
}
1067+
} else {
1068+
break;
1069+
}
1070+
}
1071+
}
1072+
//process remaining zeros
1073+
if(!sparseSafe)
1074+
for(int j=lastj+1; j<n; j++) {
1075+
if(j == 0) {
1076+
c[i] = genexec(0, b, scalars, m, n, rix+i, i, j);
1077+
} else if(c[i] != 0){
1078+
c[i] *= genexec(0, b, scalars, m, n, rix+i, i, j);
1079+
} else {
1080+
break;
1081+
}
1082+
}
1083+
lnnz += (c[i] != 0) ? 1 : 0;
1084+
}
1085+
return lnnz;
1086+
}
1087+
1088+
private long executeSparseColProd(SparseBlock sblock, SideInput[] b, double[] scalars,
1089+
MatrixBlock out, int m, int n, boolean sparseSafe, int rl, int ru, long rix)
1090+
{
1091+
double[] c = out.getDenseBlockValues();
1092+
boolean[] zeroFlag = new boolean[n];
1093+
1094+
for(int i=rl; i<ru; i++) {
1095+
int lastj = -1;
1096+
//handle non-empty rows
1097+
if(sblock != null && !sblock.isEmpty(i)) {
1098+
int apos = sblock.pos(i);
1099+
int alen = sblock.size(i);
1100+
int[] aix = sblock.indexes(i);
1101+
double[] avals = sblock.values(i);
1102+
long nnzCount = sblock.size(rl, ru);
1103+
//process every column, to not miss any 0's
1104+
for(int k=apos; k<apos+alen; k++) {
1105+
//process zeros before current non-zero
1106+
if( !sparseSafe )
1107+
for(int j=lastj+1; j<aix[k]; j++) {
1108+
if(!zeroFlag[j]) {
1109+
if(i == 0) {
1110+
c[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
1111+
} else if(c[j] != 0){
1112+
c[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
1113+
} else {
1114+
zeroFlag[j] = true;
1115+
}
1116+
}
1117+
}
1118+
//process current non-zero
1119+
if((nnzCount == m || !sparseSafe) && !zeroFlag[aix[k]]) {
1120+
if(i == 0) {
1121+
lastj = aix[k];
1122+
c[aix[k]] = genexec(avals[k], b, scalars, m, n, rix+i, i, lastj);
1123+
} else if(c[aix[k]] != 0){
1124+
lastj = aix[k];
1125+
c[aix[k]] *= genexec(avals[k], b, scalars, m, n, rix+i, i, lastj);
1126+
} else {
1127+
zeroFlag[aix[k]] = true;
1128+
}
1129+
} else {
1130+
zeroFlag[aix[k]] = true;
1131+
}
1132+
}
1133+
}
1134+
//process empty rows or remaining zeros
1135+
if(!sparseSafe)
1136+
for(int j=lastj+1; j<n; j++) {
1137+
if(!zeroFlag[j]) {
1138+
if(i == 0) {
1139+
c[j] = genexec(0, b, scalars, m, n, rix+i, i, j);
1140+
} else if(c[j] != 0){
1141+
c[j] *= genexec(0, b, scalars, m, n, rix+i, i, j);
1142+
} else {
1143+
zeroFlag[j] = true;
1144+
}
1145+
}
1146+
}
1147+
}
1148+
return -1;
1149+
}
1150+
9341151
//local execution where grix==rix
9351152
protected final double genexec( double a, SideInput[] b,
9361153
double[] scalars, int m, int n, int rix, int cix) {

src/test/java/org/apache/sysds/test/functions/codegen/CellwiseTmplTest.java

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ public class CellwiseTmplTest extends AutomatedTestBase
6767
private static final String TEST_NAME25 = TEST_NAME+25; //bias_add
6868
private static final String TEST_NAME26 = TEST_NAME+26; //bias_mult
6969
private static final String TEST_NAME27 = TEST_NAME+27; //outer < +7 negative
70+
private static final String TEST_NAME28 = TEST_NAME+28; //colProds(X^2 + 1)
71+
private static final String TEST_NAME29 = TEST_NAME+29; //colProds(2*log(X))
72+
private static final String TEST_NAME30 = TEST_NAME+30; //rowProds(X^2 + 1)
73+
private static final String TEST_NAME31 = TEST_NAME+31; //rowProds(2*log(X))
7074

7175
private static final String TEST_DIR = "functions/codegen/";
7276
private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/";
@@ -79,7 +83,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
7983
@Override
8084
public void setUp() {
8185
TestUtils.clearAssertionInformation();
82-
for( int i=1; i<=27; i++ ) {
86+
for( int i=1; i<=31; i++ ) {
8387
addTestConfiguration( TEST_NAME+i, new TestConfiguration(
8488
TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
8589
}
@@ -444,7 +448,7 @@ public void testCodegenCellwise26() {
444448
public void testCodegenCellwiseRewrite26_sp() {
445449
testCodegenIntegration( TEST_NAME26, true, ExecType.SPARK );
446450
}
447-
451+
448452
@Test
449453
public void testCodegenCellwiseRewrite27() {
450454
testCodegenIntegration( TEST_NAME27, true, ExecType.CP );
@@ -455,10 +459,71 @@ public void testCodegenCellwise27() {
455459
testCodegenIntegration( TEST_NAME27, false, ExecType.CP );
456460
}
457461

462+
@Test
458463
public void testCodegenCellwiseRewrite27_sp() {
459464
testCodegenIntegration( TEST_NAME27, true, ExecType.SPARK );
460465
}
461466

467+
@Test
468+
public void testCodegenCellwiseRewrite28() {
469+
testCodegenIntegration( TEST_NAME28, true, ExecType.CP );
470+
}
471+
472+
@Test
473+
public void testCodegenCellwise28() {
474+
testCodegenIntegration( TEST_NAME28, false, ExecType.CP );
475+
}
476+
477+
@Test
478+
public void testCodegenCellwiseRewrite28_sp() {
479+
testCodegenIntegration( TEST_NAME28, true, ExecType.SPARK );
480+
}
481+
482+
@Test
483+
public void testCodegenCellwiseRewrite29() {
484+
testCodegenIntegration( TEST_NAME29, true, ExecType.CP );
485+
}
486+
487+
@Test
488+
public void testCodegenCellwise29() {
489+
testCodegenIntegration( TEST_NAME29, false, ExecType.CP );
490+
}
491+
492+
@Test
493+
public void testCodegenCellwiseRewrite29_sp() {
494+
testCodegenIntegration( TEST_NAME29, true, ExecType.SPARK );
495+
}
496+
497+
@Test
498+
public void testCodegenCellwiseRewrite30() {
499+
testCodegenIntegration( TEST_NAME30, true, ExecType.CP );
500+
}
501+
502+
@Test
503+
public void testCodegenCellwise30() {
504+
testCodegenIntegration( TEST_NAME30, false, ExecType.CP );
505+
}
506+
507+
@Test
508+
public void testCodegenCellwiseRewrite30_sp() {
509+
testCodegenIntegration( TEST_NAME30, true, ExecType.SPARK );
510+
}
511+
512+
@Test
513+
public void testCodegenCellwiseRewrite31() {
514+
testCodegenIntegration( TEST_NAME31, true, ExecType.CP );
515+
}
516+
517+
@Test
518+
public void testCodegenCellwise31() {
519+
testCodegenIntegration( TEST_NAME31, false, ExecType.CP );
520+
}
521+
522+
@Test
523+
public void testCodegenCellwiseRewrite31_sp() {
524+
testCodegenIntegration( TEST_NAME31, true, ExecType.SPARK );
525+
}
526+
462527
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
463528
{
464529
boolean oldRewrites = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -467,15 +532,15 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType
467532

468533
if( testname.equals(TEST_NAME9) )
469534
TEST_CONF = TEST_CONF6;
470-
535+
471536
try
472537
{
473538
TestConfiguration config = getTestConfiguration(testname);
474539
loadTestConfiguration(config);
475540

476541
String HOME = SCRIPT_DIR + TEST_DIR;
477542
fullDMLScriptName = HOME + testname + ".dml";
478-
programArgs = new String[]{"-stats", "-args", output("S") };
543+
programArgs = new String[]{"-explain", "codegen", "-stats", "-args", output("S") };
479544

480545
fullRScriptName = HOME + testname + ".R";
481546
rCmd = getRCmd(inputDir(), expectedDir());

0 commit comments

Comments
 (0)