Skip to content

Commit c17fcb2

Browse files
Vi Vuong, vienthanhthai, Sandra Shi
authored and committed
[SYSTEMDS-3762] Cumulative Row Aggregates and Rewrites
Closes #2279. Co-authored-by: Vien Thanh Thai <[email protected]> Co-authored-by: Sandra Shi <[email protected]>
1 parent 455a738 commit c17fcb2

File tree

18 files changed

+967
-4
lines changed

18 files changed

+967
-4
lines changed

docs/site/dml-language-reference.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ quantile () | The p-quantile for a random variable X is the value x such that Pr
702702
quantile () | Returns a column matrix with the list of all quantiles requested in P. | Input: (X &lt;(n x 1) matrix&gt;, [W &lt;(n x 1) matrix&gt;,] P &lt;(q x 1) matrix&gt;) <br/> Output: matrix | quantile(X, P) <br/> quantile(X, W, P)
703703
median() | Computes the median in a given column matrix of values | Input: (X &lt;(n x 1) matrix&gt; [, W &lt;(n x 1) matrix&gt;]) <br/> Output: &lt;scalar&gt; | median(X) <br/> median(X,W)
704704
rowSums() <br/> rowMeans() <br/> rowVars() <br/> rowSds() <br/> rowMaxs() <br/> rowMins() | Row-wise computations -- for each row, compute the sum/mean/variance/stdDev/max/min of cell values | Input: matrix <br/> Output: (n x 1) matrix | rowSums(X) <br/> rowMeans(X) <br/> rowVars(X) <br/> rowSds(X) <br/> rowMaxs(X) <br/> rowMins(X)
705-
cumsum() | Column prefix-sum (For row-prefix sum, use cumsum(t(X)) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumsum(A) <br/> The output matrix B = [[1, 2], [4, 6], [9, 12]]
705+
cumsum() | Column prefix-sum | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumsum(A) <br/> The output matrix B = [[1, 2], [4, 6], [9, 12]]
706+
rowcumsum() | Row prefix-sum | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=2, cols=3) <br/> B = rowcumsum(A) <br/> The output matrix B = [[1, 3, 6], [4, 9, 15]]
706707
cumprod() | Column prefix-prod (for row-prefix prod, use t(cumprod(t(X)))) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("1 2 3 4 5 6", rows=3, cols=2) <br/> B = cumprod(A) <br/> The output matrix B = [[1, 2], [3, 8], [15, 48]]
707708
cummin() | Column prefix-min (for row-prefix min, use t(cummin(t(X)))) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2", rows=3, cols=2) <br/> B = cummin(A) <br/> The output matrix B = [[3, 4], [1, 4], [1, 2]]
708709
cummax() | Column prefix-max (for row-prefix max, use t(cummax(t(X)))) | Input: matrix <br/> Output: matrix of the same dimensions | A = matrix("3 4 1 6 5 2", rows=3, cols=2) <br/> B = cummax(A) <br/> The output matrix B = [[3, 4], [3, 6], [5, 6]]

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ public enum Builtins {
291291
ROLL("roll", false),
292292
ROUND("round", false),
293293
ROW_COUNT_DISTINCT("rowCountDistinct",false),
294+
ROWCUMSUM("rowcumsum", false),
294295
ROWINDEXMAX("rowIndexMax", false),
295296
ROWINDEXMIN("rowIndexMin", false),
296297
ROWMAX("rowMaxs", false),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public enum Opcodes {
3434
UAKP("uak+", InstructionType.AggregateUnary),
3535
UARKP("uark+", InstructionType.AggregateUnary),
3636
UACKP("uack+", InstructionType.AggregateUnary),
37+
UARCKP("uarck+", InstructionType.AggregateUnary),
3738
UASQKP("uasqk+", InstructionType.AggregateUnary),
3839
UARSQKP("uarsqk+", InstructionType.AggregateUnary),
3940
UACSQKP("uacsqk+", InstructionType.AggregateUnary),
@@ -151,6 +152,7 @@ public enum Opcodes {
151152
CEIL("ceil", InstructionType.Unary),
152153
FLOOR("floor", InstructionType.Unary),
153154
UCUMKP("ucumk+", InstructionType.Unary),
155+
UROWCUMKP("urowcumk+", InstructionType.Unary),
154156
UCUMM("ucum*", InstructionType.Unary),
155157
UCUMKPM("ucumk+*", InstructionType.Unary),
156158
UCUMMIN("ucummin", InstructionType.Unary),
@@ -383,6 +385,7 @@ public enum Opcodes {
383385
UCUMACMIN("ucumacmin", InstructionType.CumsumAggregate),
384386
UCUMACMAX("ucumacmax", InstructionType.CumsumAggregate),
385387
BCUMOFFKP("bcumoffk+", InstructionType.CumsumOffset),
388+
BROWCUMOFFKP("browcumoffk+", InstructionType.CumsumOffset),
386389
BCUMOFFM("bcumoff*", InstructionType.CumsumOffset),
387390
BCUMOFFPM("bcumoff+*", InstructionType.CumsumOffset),
388391
BCUMOFFMIN("bcumoffmin", InstructionType.CumsumOffset),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ public enum OpOp1 {
547547
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
548548
CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
549549
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
550-
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
550+
MEDIAN, PREFETCH, PRINT, ROUND, ROWCUMSUM, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
551551
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
552552
//fused ML-specific operators for performance
553553
SPROP, //sample proportion: P * (1 - P)
@@ -591,6 +591,7 @@ public String toString() {
591591
case MULT2: return Opcodes.MULT2.toString();
592592
case NOT: return Opcodes.NOT.toString();
593593
case POW2: return Opcodes.POW2.toString();
594+
case ROWCUMSUM: return Opcodes.UROWCUMKP.toString();
594595
case TYPEOF: return Opcodes.TYPEOF.toString();
595596
default: return name().toLowerCase();
596597
}
@@ -610,6 +611,7 @@ public static OpOp1 valueOfByOpcode(String opcode) {
610611
case "ucummin": return CUMMIN;
611612
case "ucum*": return CUMPROD;
612613
case "ucumk+": return CUMSUM;
614+
case "urowcumk+": return ROWCUMSUM;
613615
case "ucumk+*": return CUMSUMPROD;
614616
case "detectSchema": return DETECTSCHEMA;
615617
case "*2": return MULT2;

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
205205
hi = simplifyNotOverComparisons(hop, hi, i); //e.g., !(A>B) -> (A<=B)
206206
hi = simplifyMatrixScalarPMOperation(hop, hi, i); //e.g., a-A-b -> (a-b)-A; a+A-b -> (a-b)+A
207207
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
208+
hi = simplifyTransposedCumsum(hop, hi, i); //e.g., t(cumsum(t(X))) -> rowcumsum(X)
208209

209210
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
210211
if( !descendFirst )
@@ -214,6 +215,28 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
214215
hop.setVisited();
215216
}
216217

218+
/**
 * Static rewrite that collapses a transposed column-wise cumulative sum over a
 * transposed input into a single row-wise cumulative sum:
 * t(cumsum(t(X))) -> rowcumsum(X).
 *
 * The rewrite only fires if the cumsum hop has exactly one parent, and the inner
 * transpose passes the isTransposeOperation(hop, 1) check (documented inline as
 * "single consumer"), so no other consumer can still depend on the intermediates
 * that are bypassed.
 *
 * @param parent parent hop whose child reference at position pos is replaced on a match
 * @param hi     current hop under inspection (candidate outer transpose)
 * @param pos    position of hi in the input list of parent
 * @return the newly created rowcumsum hop if the pattern matched, otherwise hi unchanged
 */
private static Hop simplifyTransposedCumsum( Hop parent, Hop hi, int pos )
219+
{
220+
//pattern: t(cumsum(t(X))) -> rowcumsum(X)
221+
if( HopRewriteUtils.isTransposeOperation(hi)
222+
&& hi.getInput(0) instanceof UnaryOp
223+
&& ((UnaryOp)hi.getInput(0)).getOp() == OpOp1.CUMSUM
224+
&& hi.getInput(0).getParent().size() == 1
225+
&& HopRewriteUtils.isTransposeOperation(hi.getInput(0).getInput(0), 1)) //inner transpose with single consumer (cumsum itself is checked for a single parent above)
226+
{
227+
UnaryOp cumsum=(UnaryOp)hi.getInput(0);
228+
//X: the input of the inner transpose, i.e., the original matrix
Hop innerMatrix = cumsum.getInput(0).getInput(0);
229+
230+
//build rowcumsum(X) directly on X and rewire the parent to it, bypassing both transposes
UnaryOp rowcumsumOp = HopRewriteUtils.createUnary(innerMatrix, OpOp1.ROWCUMSUM);
231+
HopRewriteUtils.replaceChildReference(parent,hi, rowcumsumOp, pos);
232+
233+
hi = rowcumsumOp;
234+
//NOTE(review): hi was reassigned above, so the logged line is that of the new rowcumsum hop
//(whatever createUnary assigns) -- confirm it carries the original source position
LOG.debug("Applied simplifyTransposedCumsum (line "+hi.getBeginLine()+").");
235+
}
236+
237+
return hi;
238+
}
239+
217240
private Hop simplifyMatrixScalarPMOperation(Hop parent, Hop hi, int pos) {
218241
if (!(hi instanceof BinaryOp))
219242
return hi;

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,7 @@ else if( getAllExpr().length == 2 ) { //binary
10341034
break;
10351035

10361036
case CUMSUM:
1037+
case ROWCUMSUM:
10371038
case CUMPROD:
10381039
case CUMSUMPROD:
10391040
case CUMMIN:

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2616,6 +2616,7 @@ else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
26162616
case CEIL:
26172617
case FLOOR:
26182618
case CUMSUM:
2619+
case ROWCUMSUM:
26192620
case CUMPROD:
26202621
case CUMSUMPROD:
26212622
case CUMMIN:

src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class Builtin extends ValueFunction
4949

5050
public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
5151
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
52-
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
52+
STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
5353
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE,
5454
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
5555
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
@@ -95,6 +95,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
9595
String2BuiltinCode.put( "ceil" , BuiltinCode.CEIL);
9696
String2BuiltinCode.put( "floor" , BuiltinCode.FLOOR);
9797
String2BuiltinCode.put( "ucumk+" , BuiltinCode.CUMSUM);
98+
String2BuiltinCode.put( "urowcumk+" , BuiltinCode.ROWCUMSUM);
9899
String2BuiltinCode.put( "ucum*" , BuiltinCode.CUMPROD);
99100
String2BuiltinCode.put( "ucumk+*", BuiltinCode.CUMSUMPROD);
100101
String2BuiltinCode.put( "ucummin", BuiltinCode.CUMMIN);

src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,8 @@ public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(
557557
Builtin f = (Builtin)uop.fn;
558558
if( f.getBuiltinCode()==BuiltinCode.CUMSUM )
559559
return parseBasicAggregateUnaryOperator(Opcodes.UACKP.toString()) ;
560+
else if( f.getBuiltinCode()==BuiltinCode.ROWCUMSUM )
561+
return parseBasicAggregateUnaryOperator(Opcodes.UARCKP.toString()) ;
560562
else if( f.getBuiltinCode()==BuiltinCode.CUMPROD )
561563
return parseBasicAggregateUnaryOperator(Opcodes.UACM.toString()) ;
562564
else if( f.getBuiltinCode()==BuiltinCode.CUMMIN )

src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ private CumulativeOffsetSPInstruction(Operator op, CPOperand in1, CPOperand in2,
5656

5757
if (Opcodes.BCUMOFFKP.toString().equals(opcode))
5858
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
59+
else if (Opcodes.BROWCUMOFFKP.toString().equals(opcode))
60+
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("urowcumk+"));
5961
else if (Opcodes.BCUMOFFM.toString().equals(opcode))
6062
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
6163
else if (Opcodes.BCUMOFFPM.toString().equals(opcode)) {

0 commit comments

Comments
 (0)