Skip to content

Commit f8a5e96

Browse files
Vi Vuong
authored and committed
rowcumsum implementation and tests
1 parent 8874584 commit f8a5e96

File tree

13 files changed

+615
-3
lines changed

13 files changed

+615
-3
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ public enum Builtins {
291291
ROLL("roll", false),
292292
ROUND("round", false),
293293
ROW_COUNT_DISTINCT("rowCountDistinct",false),
294+
ROWCUMSUM("rowcumsum", false),
294295
ROWINDEXMAX("rowIndexMax", false),
295296
ROWINDEXMIN("rowIndexMin", false),
296297
ROWMAX("rowMaxs", false),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public enum Opcodes {
3636
UAKP("uak+", InstructionType.AggregateUnary),
3737
UARKP("uark+", InstructionType.AggregateUnary),
3838
UACKP("uack+", InstructionType.AggregateUnary),
39+
UARCKP("uarck+", InstructionType.AggregateUnary),
3940
UASQKP("uasqk+", InstructionType.AggregateUnary),
4041
UARSQKP("uarsqk+", InstructionType.AggregateUnary),
4142
UACSQKP("uacsqk+", InstructionType.AggregateUnary),
@@ -151,6 +152,7 @@ public enum Opcodes {
151152
CEIL("ceil", InstructionType.Unary),
152153
FLOOR("floor", InstructionType.Unary),
153154
UCUMKP("ucumk+", InstructionType.Unary),
155+
UROWCUMKP("urowcumk+", InstructionType.Unary),
154156
UCUMM("ucum*", InstructionType.Unary),
155157
UCUMKPM("ucumk+*", InstructionType.Unary),
156158
UCUMMIN("ucummin", InstructionType.Unary),
@@ -383,6 +385,7 @@ public enum Opcodes {
383385
UCUMACMIN("ucumacmin", InstructionType.CumsumAggregate),
384386
UCUMACMAX("ucumacmax", InstructionType.CumsumAggregate),
385387
BCUMOFFKP("bcumoffk+", InstructionType.CumsumOffset),
388+
BROWCUMOFFKP("browcumoffk+", InstructionType.CumsumOffset),
386389
BCUMOFFM("bcumoff*", InstructionType.CumsumOffset),
387390
BCUMOFFPM("bcumoff+*", InstructionType.CumsumOffset),
388391
BCUMOFFMIN("bcumoffmin", InstructionType.CumsumOffset),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ public enum OpOp1 {
545545
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
546546
CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
547547
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
548-
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
548+
MEDIAN, PREFETCH, PRINT, ROUND, ROWCUMSUM, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
549549
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
550550
//fused ML-specific operators for performance
551551
SPROP, //sample proportion: P * (1 - P)
@@ -589,6 +589,7 @@ public String toString() {
589589
case MULT2: return Opcodes.MULT2.toString();
590590
case NOT: return Opcodes.NOT.toString();
591591
case POW2: return Opcodes.POW2.toString();
592+
case ROWCUMSUM: return Opcodes.UROWCUMKP.toString();
592593
case TYPEOF: return Opcodes.TYPEOF.toString();
593594
default: return name().toLowerCase();
594595
}
@@ -608,6 +609,7 @@ public static OpOp1 valueOfByOpcode(String opcode) {
608609
case "ucummin": return CUMMIN;
609610
case "ucum*": return CUMPROD;
610611
case "ucumk+": return CUMSUM;
612+
case "urowcumk+": return ROWCUMSUM;
611613
case "ucumk+*": return CUMSUMPROD;
612614
case "detectSchema": return DETECTSCHEMA;
613615
case "*2": return MULT2;

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,6 +1034,7 @@ else if( getAllExpr().length == 2 ) { //binary
10341034
break;
10351035

10361036
case CUMSUM:
1037+
case ROWCUMSUM:
10371038
case CUMPROD:
10381039
case CUMSUMPROD:
10391040
case CUMMIN:

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2616,6 +2616,7 @@ else if ( sop.equalsIgnoreCase(Opcodes.NOTEQUAL.toString()) )
26162616
case CEIL:
26172617
case FLOOR:
26182618
case CUMSUM:
2619+
case ROWCUMSUM:
26192620
case CUMPROD:
26202621
case CUMSUMPROD:
26212622
case CUMMIN:

src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class Builtin extends ValueFunction
4949

5050
public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, ATAN, LOG, LOG_NZ, MIN,
5151
MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
52-
STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
52+
STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST,
5353
TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE,
5454
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
5555
MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
@@ -95,6 +95,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS,
9595
String2BuiltinCode.put( "ceil" , BuiltinCode.CEIL);
9696
String2BuiltinCode.put( "floor" , BuiltinCode.FLOOR);
9797
String2BuiltinCode.put( "ucumk+" , BuiltinCode.CUMSUM);
98+
String2BuiltinCode.put( "urowcumk+" , BuiltinCode.ROWCUMSUM);
9899
String2BuiltinCode.put( "ucum*" , BuiltinCode.CUMPROD);
99100
String2BuiltinCode.put( "ucumk+*", BuiltinCode.CUMSUMPROD);
100101
String2BuiltinCode.put( "ucummin", BuiltinCode.CUMMIN);

src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(
552552
Builtin f = (Builtin)uop.fn;
553553
if( f.getBuiltinCode()==BuiltinCode.CUMSUM )
554554
return parseBasicAggregateUnaryOperator(Opcodes.UACKP.toString()) ;
555+
else if( f.getBuiltinCode()==BuiltinCode.ROWCUMSUM )
556+
return parseBasicAggregateUnaryOperator(Opcodes.UARCKP.toString()) ;
555557
else if( f.getBuiltinCode()==BuiltinCode.CUMPROD )
556558
return parseBasicAggregateUnaryOperator(Opcodes.UACM.toString()) ;
557559
else if( f.getBuiltinCode()==BuiltinCode.CUMMIN )

src/main/java/org/apache/sysds/runtime/instructions/spark/CumulativeOffsetSPInstruction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ private CumulativeOffsetSPInstruction(Operator op, CPOperand in1, CPOperand in2,
5656

5757
if (Opcodes.BCUMOFFKP.toString().equals(opcode))
5858
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
59+
else if (Opcodes.BROWCUMOFFKP.toString().equals(opcode))
60+
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("urowcumk+"));
5961
else if (Opcodes.BCUMOFFM.toString().equals(opcode))
6062
_uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
6163
else if (Opcodes.BCUMOFFPM.toString().equals(opcode)) {

src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryMatrixSPInstruction.java

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,22 @@
2121

2222
import org.apache.spark.api.java.JavaPairRDD;
2323
import org.apache.spark.api.java.function.Function;
24+
import org.apache.spark.api.java.function.PairFunction;
2425
import org.apache.sysds.common.Types.DataType;
2526
import org.apache.sysds.common.Types.ValueType;
2627
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
2728
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
29+
import org.apache.sysds.runtime.functionobjects.KahanPlus;
2830
import org.apache.sysds.runtime.instructions.InstructionUtils;
2931
import org.apache.sysds.runtime.instructions.cp.CPOperand;
32+
import org.apache.sysds.runtime.instructions.cp.KahanObject;
3033
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
3134
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
3235
import org.apache.sysds.runtime.matrix.operators.Operator;
3336
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
37+
import scala.Serializable;
38+
import scala.Tuple2;
39+
import java.util.*;
3440

3541
public class UnaryMatrixSPInstruction extends UnarySPInstruction {
3642

@@ -61,6 +67,210 @@ public void processInstruction(ExecutionContext ec) {
6167
updateUnaryOutputDataCharacteristics(sec);
6268
sec.setRDDHandleForVariable(output.getName(), out);
6369
sec.addLineageRDD(output.getName(), input1.getName());
70+
71+
if ( "urowcumk+".equals(getOpcode()) ) {
72+
73+
JavaPairRDD< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > localRowcumsum = in.mapToPair( new LocalRowCumsumFunction() );
74+
75+
76+
// Collect end-values of every block of every row for offset calc by grouping by global row index
77+
JavaPairRDD< Long, Iterable<Tuple3<Long, Long, double[]>> > rowEndValues = localRowcumsum
78+
.mapToPair( tuple2 -> {
79+
80+
// get index of block
81+
MatrixIndexes indexes = tuple2._1;
82+
// get cum matrix block
83+
MatrixBlock localRowcumsumBlock = tuple2._2._2;
84+
85+
// get row and column block index
86+
long rowBlockIndex = indexes.getRowIndex();
87+
long colBlockIndex = indexes.getColumnIndex();
88+
89+
// Save end value of every row of every block (if block is empty save 0)
90+
double[] endValues = new double[ localRowcumsumBlock.getNumRows() ];
91+
92+
for ( int i = 0; i < localRowcumsumBlock.getNumRows(); i ++ ) {
93+
if (localRowcumsumBlock.getNumColumns() > 0) {
94+
endValues[i] = localRowcumsumBlock.get(i, localRowcumsumBlock.getNumColumns() - 1);
95+
} else {
96+
endValues[i] = 0.0 ;
97+
}
98+
}
99+
100+
return new Tuple2<>(rowBlockIndex, new Tuple3<>(rowBlockIndex, colBlockIndex, endValues));
101+
}
102+
103+
).groupByKey();
104+
105+
106+
107+
108+
// compute offset for every block
109+
List< Tuple2 <Tuple2<Long, Long>, double[]> > offsetList = rowEndValues
110+
.flatMapToPair(tuple2 -> {
111+
112+
Long rowBlockIndex = tuple2._1;
113+
114+
List< Tuple3<Long, Long, double[]> > colValues = new ArrayList<>();
115+
for ( Tuple3<Long, Long, double[]> cv : tuple2._2 ) {
116+
colValues.add(cv);
117+
}
118+
119+
// sort blocks from one row by column index
120+
colValues.sort(Comparator.comparing(Tuple3::_2));
121+
122+
// get number of rows of a block by counting amount of end (row) values of said block
123+
int numberOfRows = 0;
124+
if ( !colValues.isEmpty() ) {
125+
Tuple3<Long, Long, double[]> firstTuple = colValues.get(0);
126+
double[] lastValuesArray = firstTuple._3();
127+
numberOfRows = lastValuesArray.length;
128+
}
129+
130+
131+
List<Tuple2<Tuple2<Long, Long>, double[]>> blockOffsets = new ArrayList<>();
132+
133+
double[] cumulativeOffsets = new double[numberOfRows];
134+
135+
for (Tuple3<Long, Long, double[]> colValue : colValues) {
136+
137+
Long colBlockIndex = colValue._2();
138+
double[] endValues = colValue._3();
139+
140+
// copy current offsets
141+
double[] currentOffsets = cumulativeOffsets.clone();
142+
143+
// and save block indexes with its offsets
144+
blockOffsets.add( new Tuple2<>(new Tuple2<>(rowBlockIndex, colBlockIndex), currentOffsets) );
145+
146+
for ( int i = 0; i < numberOfRows && i < endValues.length; i++ ) {
147+
cumulativeOffsets[i] += endValues[i];
148+
}
149+
150+
}
151+
return blockOffsets.iterator();
152+
}
153+
).collect();
154+
155+
156+
// convert list to map for easier access to offsets
157+
Map< Tuple2<Long, Long>, double[] > offsetMap = new HashMap<>();
158+
for (Tuple2<Tuple2<Long, Long>, double[]> offset : offsetList) {
159+
offsetMap.put(offset._1, offset._2);
160+
}
161+
162+
163+
out = localRowcumsum.mapToPair( new FinalRowCumsumFunction(offsetMap)) ;
164+
165+
updateUnaryOutputDataCharacteristics(sec);
166+
sec.setRDDHandleForVariable(output.getName(), out);
167+
sec.addLineageRDD(output.getName(), input1.getName());
168+
}
169+
}
170+
171+
172+
173+
private static class LocalRowCumsumFunction implements PairFunction< Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > {
174+
175+
@Override
176+
public Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > call(Tuple2<MatrixIndexes, MatrixBlock> tuple2) {
177+
178+
179+
MatrixBlock inputBlock = tuple2._2;
180+
MatrixBlock cumsumBlock = new MatrixBlock( inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
181+
182+
183+
for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
184+
185+
KahanObject kbuff = new KahanObject(0, 0);
186+
KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
187+
188+
for ( int j = 0; j < inputBlock.getNumColumns(); j++ ) {
189+
190+
double val = inputBlock.get(i, j);
191+
kplus.execute2(kbuff, val);
192+
cumsumBlock.set(i, j, kbuff._sum);
193+
}
194+
}
195+
// original index, original matrix and local cumsum block
196+
return new Tuple2<>( tuple2._1, new Tuple2<>(inputBlock, cumsumBlock) );
197+
}
198+
}
199+
200+
201+
202+
203+
private static class FinalRowCumsumFunction implements PairFunction<Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> >, MatrixIndexes, MatrixBlock> {
204+
205+
206+
// map block indexes to the row offsets
207+
private final Map< Tuple2<Long, Long>, double[] > offsetMap;
208+
209+
public FinalRowCumsumFunction(Map<Tuple2<Long, Long>, double[]> offsetMap) {
210+
this.offsetMap = offsetMap;
211+
}
212+
213+
214+
@Override
215+
public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2< MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock> > tuple ) {
216+
217+
MatrixIndexes indexes = tuple._1;
218+
MatrixBlock inputBlock = tuple._2._1;
219+
MatrixBlock localRowCumsumBlock = tuple._2._2;
220+
221+
// key to get the row offset for this block
222+
Tuple2<Long, Long> blockKey = new Tuple2<>( indexes.getRowIndex(), indexes.getColumnIndex()) ;
223+
double[] offsets = offsetMap.get(blockKey);
224+
225+
MatrixBlock cumsumBlock = new MatrixBlock( inputBlock.getNumRows(), inputBlock.getNumColumns(), false );
226+
227+
228+
for ( int i = 0; i < inputBlock.getNumRows(); i++ ) {
229+
230+
double rowOffset = 0.0;
231+
if ( offsets != null && i < offsets.length ) {
232+
rowOffset = offsets[i];
233+
}
234+
235+
for ( int j = 0; j < inputBlock.getNumColumns(); j++ ) {
236+
double cumsumValue = localRowCumsumBlock.get(i, j);
237+
cumsumBlock.set(i, j, cumsumValue + rowOffset);
238+
}
239+
}
240+
241+
// block index and final cumsum block
242+
return new Tuple2<>(indexes, cumsumBlock);
243+
}
244+
}
245+
246+
247+
248+
// helper class
249+
private static class Tuple3<Type1, Type2, Type3> implements Serializable {
250+
251+
private static final long serialVersionUID = 123;
252+
private final Type1 _1;
253+
private final Type2 _2;
254+
private final Type3 _3;
255+
256+
257+
public Tuple3( Type1 _1, Type2 _2, Type3 _3 ) {
258+
this._1 = _1;
259+
this._2 = _2;
260+
this._3 = _3;
261+
}
262+
263+
public Type1 _1() {
264+
return _1;
265+
}
266+
267+
public Type2 _2() {
268+
return _2;
269+
}
270+
271+
public Type3 _3() {
272+
return _3;
273+
}
64274
}
65275

66276
private static class RDDMatrixBuiltinUnaryOp implements Function<MatrixBlock,MatrixBlock>

0 commit comments

Comments
 (0)