Skip to content

Commit eb2ca62

Browse files
committed
[SYSTEMDS-3815] Fused table sequence
This commit contains a new fused operator for: table(seq(1, nrow(A)), A, w) That removes the need to generate a vector of incrementing integers in the size of A. Previously, we already had support for this operator and called it rexpand. However, that implementation still allocated the seq vector. We see a 1.4x improvement in the rexpand operator, and with the addition of removing the seq allocation, it further improves to 1.72x. The change is not fully integrated into the Federated Instructions and needs additional work. The current workaround tries to compile the previous instruction for federated use cases. Closes #2181
1 parent 9484f11 commit eb2ca62

File tree

23 files changed

+1199
-100
lines changed

23 files changed

+1199
-100
lines changed

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ else if( isCTableReshapeRewriteApplicable(et, ternaryOp) ) {
307307
}
308308

309309
Ctable ternary = new Ctable(inputLops, ternaryOp,
310-
getDataType(), getValueType(), ignoreZeros, outputEmptyBlocks, et);
310+
getDataType(), getValueType(), ignoreZeros, outputEmptyBlocks, et, OptimizerUtils.getConstrainedNumThreads(getMaxNumThreads()));
311311

312312
ternary.getOutputParameters().setDimensions(getDim1(), getDim2(), getBlocksize(), -1);
313313
setLineNumbers(ternary);
@@ -480,6 +480,10 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
480480
}
481481

482482

483+
public ExecType findExecTypeTernaryOp(){
484+
return _etype == null ? optFindExecType(OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE) : _etype;
485+
}
486+
483487
@Override
484488
protected ExecType optFindExecType(boolean transitive)
485489
{
@@ -637,7 +641,7 @@ && getInput().get(1) == that2.getInput().get(1)
637641
return ret;
638642
}
639643

640-
private boolean isSequenceRewriteApplicable( boolean left )
644+
public boolean isSequenceRewriteApplicable( boolean left )
641645
{
642646
boolean ret = false;
643647

@@ -651,7 +655,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
651655
{
652656
Hop input1 = getInput().get(0);
653657
Hop input2 = getInput().get(1);
654-
if( input1.getDataType() == DataType.MATRIX && input2.getDataType() == DataType.MATRIX )
658+
if( (input1.getDataType() == DataType.MATRIX
659+
|| input1.getDataType() == DataType.SCALAR )
660+
&& input2.getDataType() == DataType.MATRIX )
655661
{
656662
//probe rewrite on left input
657663
if( left && input1 instanceof DataGenOp )
@@ -663,6 +669,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
663669
|| dgop.getIncrementValue()==1.0; //set by recompiler
664670
}
665671
}
672+
if( left && input1 instanceof LiteralOp && ((LiteralOp)input1).getStringValue().contains("seq(")){
673+
ret = true;
674+
}
666675
//probe rewrite on right input
667676
if( !left && input2 instanceof DataGenOp )
668677
{

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.sysds.common.Types.AggOp;
3131
import org.apache.sysds.common.Types.DataType;
3232
import org.apache.sysds.common.Types.Direction;
33+
import org.apache.sysds.common.Types.ExecType;
3334
import org.apache.sysds.common.Types.OpOp1;
3435
import org.apache.sysds.common.Types.OpOp2;
3536
import org.apache.sysds.common.Types.OpOp3;
@@ -209,6 +210,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
209210
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
210211
if( !descendFirst )
211212
rule_AlgebraicSimplification(hi, descendFirst);
213+
214+
hi = fuseSeqAndTableExpand(hi);
212215
}
213216

214217
hop.setVisited();
@@ -2913,4 +2916,24 @@ private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
29132916
}
29142917
return hi;
29152918
}
2919+
2920+
2921+
private static Hop fuseSeqAndTableExpand(Hop hi) {
2922+
2923+
if(TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES && hi instanceof TernaryOp ) {
2924+
TernaryOp thop = (TernaryOp) hi;
2925+
thop.getOp();
2926+
2927+
if(thop.isSequenceRewriteApplicable(true) && thop.findExecTypeTernaryOp() == ExecType.CP &&
2928+
thop.getInput(1).getForcedExecType() != Types.ExecType.FED) {
2929+
Hop input1 = thop.getInput(0);
2930+
if(input1 instanceof DataGenOp){
2931+
Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
2932+
HopRewriteUtils.replaceChildReference(hi, input1, literal);
2933+
}
2934+
}
2935+
2936+
}
2937+
return hi;
2938+
}
29162939
}

src/main/java/org/apache/sysds/lops/Ctable.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public class Ctable extends Lop
3636
{
3737
private final boolean _ignoreZeros;
3838
private final boolean _outputEmptyBlocks;
39+
private final int _numThreads;
3940

4041
public enum OperationTypes {
4142
CTABLE_TRANSFORM,
@@ -58,15 +59,16 @@ public boolean hasThirdInput() {
5859
OperationTypes operation;
5960

6061

61-
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et) {
62-
this(inputLops, op, dt, vt, false, true, et);
62+
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, ExecType et, int k) {
63+
this(inputLops, op, dt, vt, false, true, et, k);
6364
}
6465

65-
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, boolean outputEmptyBlocks, ExecType et) {
66+
public Ctable(Lop[] inputLops, OperationTypes op, DataType dt, ValueType vt, boolean ignoreZeros, boolean outputEmptyBlocks, ExecType et, int k) {
6667
super(Lop.Type.Ctable, dt, vt);
6768
init(inputLops, op, et);
6869
_ignoreZeros = ignoreZeros;
6970
_outputEmptyBlocks = outputEmptyBlocks;
71+
_numThreads = k;
7072
}
7173

7274
private void init(Lop[] inputLops, OperationTypes op, ExecType et) {
@@ -175,6 +177,10 @@ public String getInstructions(String input1, String input2, String input3, Strin
175177
sb.append( OPERAND_DELIMITOR );
176178
sb.append( _outputEmptyBlocks );
177179
}
180+
else {
181+
sb.append( OPERAND_DELIMITOR );
182+
sb.append(_numThreads);
183+
}
178184

179185
return sb.toString();
180186
}

src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,8 +1987,8 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV
19871987
case DECOMPRESS:
19881988
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
19891989
checkNumParameters(1);
1990-
checkMatrixParam(getFirstExpr());
1991-
output.setDataType(DataType.MATRIX);
1990+
checkMatrixFrameParam(getFirstExpr());
1991+
output.setDataType(getFirstExpr().getOutput().getDataType());
19921992
output.setDimensions(id.getDim1(), id.getDim2());
19931993
output.setBlocksize (id.getBlocksize());
19941994
output.setValueType(id.getValueType());

src/main/java/org/apache/sysds/resource/cost/CostEstimator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ public double parseSPInst(SPInstruction inst) throws CostEstimationException {
822822
SparkCostUtils.getMatMulChainInstTime(mmchaininst, input1, input2, input3, output, driverMetrics, executorMetrics);
823823
} else if (inst instanceof CtableSPInstruction) {
824824
CtableSPInstruction tableInst = (CtableSPInstruction) inst;
825-
VarStats input1 = getStats(tableInst.input1.getName());
825+
VarStats input1 = getStatsWithDefaultScalar(tableInst.input1.getName());
826826
VarStats input2 = getStatsWithDefaultScalar(tableInst.input2.getName());
827827
VarStats input3 = getStatsWithDefaultScalar(tableInst.input3.getName());
828828
double loadTime = loadRDDStatsAndEstimateTime(input1) +

src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
5454
import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
5555
import org.apache.sysds.runtime.compress.lib.CLALibMerge;
56+
import org.apache.sysds.runtime.compress.lib.CLALibReshape;
5657
import org.apache.sysds.runtime.compress.lib.CLALibRexpand;
5758
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
5859
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
@@ -1281,6 +1282,11 @@ public MatrixBlock transpose(int k) {
12811282
return getUncompressed().transpose(k);
12821283
}
12831284

1285+
@Override
1286+
public MatrixBlock reshape(int rows,int cols, boolean byRow){
1287+
return CLALibReshape.reshape(this, rows, cols, byRow);
1288+
}
1289+
12841290
@Override
12851291
public String toString() {
12861292
StringBuilder sb = new StringBuilder();

src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -595,10 +595,8 @@ public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, in
595595
public void leftMMIdentityPreAggregateDense(MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu) {
596596
DenseBlock db = that.getDenseBlock();
597597
DenseBlock retDB = ret.getDenseBlock();
598-
if(rl == ru - 1)
599-
leftMMIdentityPreAggregateDenseSingleRow(db.values(rl), db.pos(rl), retDB.values(rl), retDB.pos(rl), cl, cu);
600-
else
601-
throw new NotImplementedException();
598+
for(int i = rl; i < ru; i++)
599+
leftMMIdentityPreAggregateDenseSingleRow(db.values(i), db.pos(i), retDB.values(i), retDB.pos(i), cl, cu);
602600
}
603601

604602
@Override
@@ -632,7 +630,8 @@ public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, i
632630
}
633631
}
634632

635-
final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen) {
633+
final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k,
634+
int vLen) {
636635
// vVec = vVec.broadcast(aa);
637636
final int offj = k * jd;
638637
final int end = endT + offj;
@@ -919,16 +918,16 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret,
919918

920919
@Override
921920
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
922-
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
923-
final SparseBlock sb = selection.getSparseBlock();
924-
final DenseBlock retB = ret.getDenseBlock();
925-
for(int r = rl; r < ru; r++) {
926-
if(sb.isEmpty(r))
927-
continue;
928-
final int sPos = sb.pos(r);
929-
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
930-
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
931-
}
921+
// morph(CompressionType.UNCOMPRESSED, _data.size()).sparseSelection(selection, ret, rl, ru);;
922+
final SparseBlock sb = selection.getSparseBlock();
923+
final DenseBlock retB = ret.getDenseBlock();
924+
for(int r = rl; r < ru; r++) {
925+
if(sb.isEmpty(r))
926+
continue;
927+
final int sPos = sb.pos(r);
928+
final int rowCompressed = sb.indexes(r)[sPos]; // column index with 1
929+
decompressToDenseBlock(retB, rowCompressed, rowCompressed + 1, r - rowCompressed, 0);
930+
}
932931
}
933932

934933

@@ -946,22 +945,21 @@ private void leftMMIdentityPreAggregateDenseSingleRow(double[] values, int pos,
946945
for(int rc = cl; rc < cu; rc++, pos++) {
947946
final int idx = _data.getIndex(rc);
948947
if(idx != nVal)
949-
values2[_colIndexes.get(idx)] += values[pos];
948+
values2[pos2 + _colIndexes.get(idx)] += values[pos];
950949
}
951950
}
952951
else {
953952
for(int rc = cl; rc < cu; rc++, pos++)
954-
values2[_colIndexes.get(_data.getIndex(rc))] += values[pos];
953+
values2[pos2 + _colIndexes.get(_data.getIndex(rc))] += values[pos];
955954
}
956955
}
957956
}
958957

959-
960958
private void leftMMIdentityPreAggregateDenseSingleRowRangeIndex(double[] values, int pos, double[] values2, int pos2,
961959
int cl, int cu) {
962960
IdentityDictionary a = (IdentityDictionary) _dict;
963961

964-
final int firstCol = _colIndexes.get(0);
962+
final int firstCol = pos2 + _colIndexes.get(0);
965963
pos += cl; // left side matrix position offset.
966964
if(a.withEmpty()) {
967965
final int nVal = _dict.getNumberOfValues(_colIndexes.size()) - 1;

0 commit comments

Comments
 (0)