Skip to content

Commit f7a4f3c

Browse files
committed
rewrite introduced ... maybe not working in spark and federated
1 parent a546751 commit f7a4f3c

File tree

5 files changed

+32
-5
lines changed

5 files changed

+32
-5
lines changed

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
651651
{
652652
Hop input1 = getInput().get(0);
653653
Hop input2 = getInput().get(1);
654-
if( input1.getDataType() == DataType.MATRIX && input2.getDataType() == DataType.MATRIX )
654+
if( (input1.getDataType() == DataType.MATRIX
655+
|| input1.getDataType() == DataType.SCALAR )
656+
&& input2.getDataType() == DataType.MATRIX )
655657
{
656658
//probe rewrite on left input
657659
if( left && input1 instanceof DataGenOp )
@@ -663,6 +665,9 @@ private boolean isSequenceRewriteApplicable( boolean left )
663665
|| dgop.getIncrementValue()==1.0; //set by recompiler
664666
}
665667
}
668+
if( left && input1 instanceof LiteralOp){
669+
ret = true;
670+
}
666671
//probe rewrite on right input
667672
if( !left && input2 instanceof DataGenOp )
668673
{

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ public static boolean isSequenceSizeOfA(Hop hop, Hop A)
13971397
{
13981398
boolean ret = false;
13991399

1400-
if((hop instanceof DataGenOp) && hop.getExecType() == ExecType.CP && A.getExecType() != ExecType.CP ) {
1400+
if((hop instanceof DataGenOp)) {
14011401
DataGenOp dgop = (DataGenOp) hop;
14021402
if(dgop.getOp() == OpOpDG.SEQ) {
14031403
Hop from = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_FROM));

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
198198
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
199199

200200
hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m))
201+
hi = fuseSeqAndTableExpand(hi);
201202

202203
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
203204
if( !descendFirst )
@@ -2195,4 +2196,24 @@ private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
21952196
}
21962197
}
21972198
}
2199+
2200+
private static Hop fuseSeqAndTableExpand(Hop hi) {
2201+
2202+
if(hi instanceof TernaryOp) {
2203+
TernaryOp thop = (TernaryOp) hi;
2204+
thop.getOp();
2205+
if(thop.getOp() == OpOp3.CTABLE) {
2206+
Hop input1 = thop.getInput(0);
2207+
Hop input2 = thop.getInput(1);
2208+
// Hop input3 = thop.getInput().size() == 3 ? thop.getInput(2) : null;
2209+
2210+
if(HopRewriteUtils.isSequenceSizeOfA(input1, input2)) {
2211+
Hop literal = new LiteralOp(input1.getDim1());
2212+
HopRewriteUtils.replaceChildReference(hi, input1, literal);
2213+
}
2214+
}
2215+
2216+
}
2217+
return hi;
2218+
}
21982219
}

src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,9 @@ private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOper
6666
}
6767

6868
public static CtableFEDInstruction parseInstruction(CtableCPInstruction inst, ExecutionContext ec) {
69-
if((inst.getOpcode().equalsIgnoreCase("ctable") || inst.getOpcode().equalsIgnoreCase("ctableexpand")) &&
70-
(ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
69+
if((inst.getOpcode().equalsIgnoreCase("ctable")
70+
|| inst.getOpcode().equalsIgnoreCase("ctableexpand")) &&
71+
(inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
7172
(inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederated(FType.ROW)) ||
7273
(inst.input3.isMatrix() && ec.getCacheableData(inst.input3).isFederated(FType.ROW))))
7374
return CtableFEDInstruction.parseInstruction(inst);

src/test/java/org/apache/sysds/test/functions/compress/wordembedding/WordEmbeddingUseCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public void wordEmb(int rows, int unique, int l, int embeddingSize, ExecType ins
102102
writeBinaryWithMTD("W", W);
103103

104104
runTest(null);
105-
105+
106106
MatrixBlock R = TestUtils.readBinary(output("R"));
107107

108108
analyzeResult(X, W, R, l);

0 commit comments

Comments
 (0)