Skip to content

Commit 197e4fe

Browse files
committed
spark and cp support
1 parent f7a4f3c commit 197e4fe

File tree

4 files changed

+51
-17
lines changed

4 files changed

+51
-17
lines changed

src/main/java/org/apache/sysds/hops/TernaryOp.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,10 @@ protected DataCharacteristics inferOutputCharacteristics( MemoTable memo )
480480
}
481481

482482

483+
public ExecType findExecTypeTernaryOp(){
484+
return _etype == null ? optFindExecType(OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE) : _etype;
485+
}
486+
483487
@Override
484488
protected ExecType optFindExecType(boolean transitive)
485489
{

src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.apache.sysds.api.DMLScript;
2626
import org.apache.sysds.common.Types.DataType;
2727
import org.apache.sysds.common.Types.ExecMode;
28-
import org.apache.sysds.common.Types.ExecType;
2928
import org.apache.sysds.common.Types.FileFormat;
3029
import org.apache.sysds.common.Types.OpOp1;
3130
import org.apache.sysds.common.Types.OpOp2;
@@ -1397,7 +1396,7 @@ public static boolean isSequenceSizeOfA(Hop hop, Hop A)
13971396
{
13981397
boolean ret = false;
13991398

1400-
if((hop instanceof DataGenOp)) {
1399+
if((hop instanceof DataGenOp) ) {
14011400
DataGenOp dgop = (DataGenOp) hop;
14021401
if(dgop.getOp() == OpOpDG.SEQ) {
14031402
Hop from = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_FROM));

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@
2727
import java.util.List;
2828
import java.util.Set;
2929

30+
import org.apache.sysds.common.Types.AggOp;
31+
import org.apache.sysds.common.Types.DataType;
32+
import org.apache.sysds.common.Types.Direction;
33+
import org.apache.sysds.common.Types.ExecType;
34+
import org.apache.sysds.common.Types.OpOp1;
35+
import org.apache.sysds.common.Types.OpOp2;
36+
import org.apache.sysds.common.Types.OpOp3;
37+
import org.apache.sysds.common.Types.OpOpDG;
38+
import org.apache.sysds.common.Types.OpOpData;
39+
import org.apache.sysds.common.Types.OpOpN;
40+
import org.apache.sysds.common.Types.ParamBuiltinOp;
41+
import org.apache.sysds.common.Types.ReOrgOp;
42+
import org.apache.sysds.common.Types.ValueType;
3043
import org.apache.sysds.hops.AggBinaryOp;
3144
import org.apache.sysds.hops.AggUnaryOp;
3245
import org.apache.sysds.hops.BinaryOp;
@@ -40,20 +53,8 @@
4053
import org.apache.sysds.hops.ReorgOp;
4154
import org.apache.sysds.hops.TernaryOp;
4255
import org.apache.sysds.hops.UnaryOp;
43-
import org.apache.sysds.common.Types.AggOp;
44-
import org.apache.sysds.common.Types.Direction;
45-
import org.apache.sysds.common.Types.OpOp1;
46-
import org.apache.sysds.common.Types.OpOp2;
47-
import org.apache.sysds.common.Types.OpOp3;
48-
import org.apache.sysds.common.Types.OpOpDG;
49-
import org.apache.sysds.common.Types.OpOpData;
50-
import org.apache.sysds.common.Types.OpOpN;
51-
import org.apache.sysds.common.Types.ParamBuiltinOp;
52-
import org.apache.sysds.common.Types.ReOrgOp;
5356
import org.apache.sysds.parser.DataExpression;
5457
import org.apache.sysds.parser.Statement;
55-
import org.apache.sysds.common.Types.DataType;
56-
import org.apache.sysds.common.Types.ValueType;
5758

5859
/**
5960
* Rule: Algebraic Simplifications. Simplifies binary expressions
@@ -2199,10 +2200,10 @@ private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
21992200

22002201
private static Hop fuseSeqAndTableExpand(Hop hi) {
22012202

2202-
if(hi instanceof TernaryOp) {
2203+
if(hi instanceof TernaryOp ) {
22032204
TernaryOp thop = (TernaryOp) hi;
22042205
thop.getOp();
2205-
if(thop.getOp() == OpOp3.CTABLE) {
2206+
if(thop.getOp() == OpOp3.CTABLE && thop.findExecTypeTernaryOp() == ExecType.CP) {
22062207
Hop input1 = thop.getInput(0);
22072208
Hop input2 = thop.getInput(1);
22082209
// Hop input3 = thop.getInput().size() == 3 ? thop.getInput(2) : null;

src/test/java/org/apache/sysds/test/functions/compress/wordembedding/WordEmbeddingUseCase.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import static org.junit.Assert.assertEquals;
2323
import static org.junit.Assert.assertTrue;
24+
import static org.junit.Assert.fail;
2425

2526
import java.io.File;
2627

@@ -80,6 +81,31 @@ public void testWordEmb_moreUniqueWordsThanSentences() {
8081
wordEmb(100, 200, 5, 2, ExecType.CP, "01");
8182
}
8283

84+
@Test
85+
public void testWordEmbSP() {
86+
wordEmb(10, 2, 2, 2, ExecType.SPARK, "01");
87+
}
88+
89+
@Test
90+
public void testWordEmb_mediumSP() {
91+
wordEmb(100, 30, 4, 3, ExecType.SPARK, "01");
92+
}
93+
94+
@Test
95+
public void testWordEmb_bigWordsSP() {
96+
wordEmb(10, 2, 2, 10, ExecType.SPARK, "01");
97+
}
98+
99+
@Test
100+
public void testWordEmb_longSentencesSP() {
101+
wordEmb(100, 30, 5, 2, ExecType.SPARK, "01");
102+
}
103+
104+
@Test
105+
public void testWordEmb_moreUniqueWordsThanSentencesSP() {
106+
wordEmb(100, 200, 5, 2, ExecType.SPARK, "01");
107+
}
108+
83109
public void wordEmb(int rows, int unique, int l, int embeddingSize, ExecType instType, String name) {
84110

85111
OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
@@ -101,12 +127,16 @@ public void wordEmb(int rows, int unique, int l, int embeddingSize, ExecType ins
101127
MatrixBlock W = TestUtils.generateTestMatrixBlock(unique, embeddingSize, 1.0, -1, 1, 32);
102128
writeBinaryWithMTD("W", W);
103129

104-
runTest(null);
130+
String r = runTest(null).toString();
105131

106132
MatrixBlock R = TestUtils.readBinary(output("R"));
107133

108134
analyzeResult(X, W, R, l);
109135

136+
if( instType == ExecType.CP && heavyHittersContainsString("seq")){
137+
fail("cp should not have seq instruction\n" + r);
138+
}
139+
110140
}
111141
catch(Exception e) {
112142
e.printStackTrace();

0 commit comments

Comments
 (0)