Skip to content

Commit d9ebcf0

Browse files
committed
rewrite ... not allowed if federated in
1 parent 67e43ee commit d9ebcf0

File tree

3 files changed

+39
-35
lines changed

3 files changed

+39
-35
lines changed

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.sysds.common.Types.AggOp;
3131
import org.apache.sysds.common.Types.DataType;
3232
import org.apache.sysds.common.Types.Direction;
33+
import org.apache.sysds.common.Types.ExecType;
3334
import org.apache.sysds.common.Types.OpOp1;
3435
import org.apache.sysds.common.Types.OpOp2;
3536
import org.apache.sysds.common.Types.OpOp3;
@@ -209,6 +210,8 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
209210
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
210211
if( !descendFirst )
211212
rule_AlgebraicSimplification(hi, descendFirst);
213+
214+
hi = fuseSeqAndTableExpand(hi);
212215
}
213216

214217
hop.setVisited();
@@ -2913,4 +2916,24 @@ private static Hop simplyfyMMCBindZeroVector(Hop parent, Hop hi, int pos) {
29132916
}
29142917
return hi;
29152918
}
2919+
2920+
2921+
private static Hop fuseSeqAndTableExpand(Hop hi) {
2922+
2923+
if(TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES && hi instanceof TernaryOp ) {
2924+
TernaryOp thop = (TernaryOp) hi;
2925+
thop.getOp();
2926+
2927+
if(thop.isSequenceRewriteApplicable(true) && thop.findExecTypeTernaryOp() == ExecType.CP &&
2928+
thop.getInput(1).getForcedExecType() != Types.ExecType.FED) {
2929+
Hop input1 = thop.getInput(0);
2930+
if(input1 instanceof DataGenOp){
2931+
Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
2932+
HopRewriteUtils.replaceChildReference(hi, input1, literal);
2933+
}
2934+
}
2935+
2936+
}
2937+
return hi;
2938+
}
29162939
}

src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
import org.apache.sysds.hops.UnaryOp;
4343
import org.apache.sysds.common.Types.AggOp;
4444
import org.apache.sysds.common.Types.Direction;
45-
import org.apache.sysds.common.Types.ExecType;
4645
import org.apache.sysds.common.Types.OpOp1;
4746
import org.apache.sysds.common.Types.OpOp2;
4847
import org.apache.sysds.common.Types.OpOp3;
@@ -199,7 +198,6 @@ private void rule_AlgebraicSimplification(Hop hop, boolean descendFirst)
199198
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
200199

201200
hi = fixNonScalarPrint(hop, hi, i); //e.g., print(m) -> print(toString(m))
202-
hi = fuseSeqAndTableExpand(hi);
203201

204202
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
205203
if( !descendFirst )
@@ -2197,22 +2195,4 @@ private static void removeTWriteTReadPairs(ArrayList<Hop> roots) {
21972195
}
21982196
}
21992197
}
2200-
2201-
private static Hop fuseSeqAndTableExpand(Hop hi) {
2202-
2203-
if(TernaryOp.ALLOW_CTABLE_SEQUENCE_REWRITES && hi instanceof TernaryOp ) {
2204-
TernaryOp thop = (TernaryOp) hi;
2205-
thop.getOp();
2206-
2207-
if(thop.isSequenceRewriteApplicable(true) && thop.findExecTypeTernaryOp() == ExecType.CP) {
2208-
Hop input1 = thop.getInput(0);
2209-
if(input1 instanceof DataGenOp){
2210-
Hop literal = new LiteralOp("seq(1, "+input1.getDim1() +")");
2211-
HopRewriteUtils.replaceChildReference(hi, input1, literal);
2212-
}
2213-
}
2214-
2215-
}
2216-
return hi;
2217-
}
22182198
}

src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ private void runTestMM(String fileX, String fileY, long driverMemory, int number
235235

236236
// original compilation used for comparison
237237
Program expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
238-
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
238+
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, new StringBuilder());
239239

240240
Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream()
241241
.filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode)))
@@ -257,7 +257,7 @@ private void runTestTSMM(String fileX, long driverMemory, int numberExecutors, l
257257
}
258258
// original compilation used for comparison
259259
Program expectedProgram = ResourceCompiler.compile(HOME+"mm_transpose_test.dml", nvargs);
260-
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
260+
Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, new StringBuilder());
261261
Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream()
262262
.filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode)))
263263
.findFirst();
@@ -273,22 +273,23 @@ private void runTestAlgorithm(String dmlScript, long driverMemory, int numberExe
273273
Map<String, String> nvargs) throws IOException {
274274
// pre-compiled program using default values to be used as source for the recompilation
275275
Program precompiledProgram = generateInitialProgram(HOME+dmlScript, nvargs);
276-
System.out.println("precompiled");
277-
System.out.println(Explain.explain(precompiledProgram));
276+
StringBuilder sb = new StringBuilder();
277+
sb.append("\n\nprecompiled\n");
278+
sb.append(Explain.explain(precompiledProgram));
278279
if (numberExecutors > 0) {
279280
ResourceCompiler.setSparkClusterResourceConfigs(driverMemory, driverThreads, numberExecutors, executorMemory, executorThreads);
280281
} else {
281282
ResourceCompiler.setSingleNodeResourceConfigs(driverMemory, driverThreads);
282283
}
283284
// original compilation used for comparison
284285
Program expectedProgram = ResourceCompiler.compile(HOME+dmlScript, nvargs);
285-
System.out.println("expected");
286-
System.out.println(Explain.explain(expectedProgram));
287-
runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory);
286+
sb.append("\n\nexpected\n");
287+
sb.append(Explain.explain(expectedProgram));
288+
runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, sb);
288289
}
289290

290-
private Program runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory) {
291-
if (DEBUG_MODE) System.out.println(Explain.explain(expectedProgram));
291+
private Program runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory, StringBuilder sb) {
292+
if (DEBUG_MODE) sb.append(Explain.explain(expectedProgram));
292293
Program recompiledProgram;
293294
if (numberExecutors == 0) {
294295
ResourceCompiler.setSingleNodeResourceConfigs(driverMemory, driverThreads);
@@ -303,19 +304,19 @@ private Program runTest(Program precompiledProgram, Program expectedProgram, lon
303304
);
304305
recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram);
305306
}
306-
System.out.println("recompiled");
307-
System.out.println(Explain.explain(recompiledProgram));
307+
sb.append("\n\nrecompiled\n");
308+
sb.append(Explain.explain(recompiledProgram));
308309

309-
if (DEBUG_MODE) System.out.println(Explain.explain(recompiledProgram));
310-
assertEqualPrograms(expectedProgram, recompiledProgram);
310+
if (DEBUG_MODE) sb.append(Explain.explain(recompiledProgram));
311+
assertEqualPrograms(expectedProgram, recompiledProgram, sb);
311312
return recompiledProgram;
312313
}
313314

314-
private void assertEqualPrograms(Program expected, Program actual) {
315+
private void assertEqualPrograms(Program expected, Program actual, StringBuilder sb) {
315316
// strip empty blocks basic program blocks
316317
String expectedProgramExplained = stripGeneralAndReplaceRandoms(Explain.explain(expected));
317318
String actualProgramExplained = stripGeneralAndReplaceRandoms(Explain.explain(actual));
318-
Assert.assertEquals(expectedProgramExplained, actualProgramExplained);
319+
Assert.assertEquals(sb.toString(), expectedProgramExplained, actualProgramExplained);
319320
}
320321

321322
private String stripGeneralAndReplaceRandoms(String explainedProgram) {

0 commit comments

Comments
 (0)