Skip to content

Commit a0e1a17

Browse files
committed
Bugfix in CodeGen
To properly remove references to created operators if broadcasting checks fail
1 parent 4be1bd9 commit a0e1a17

File tree

2 files changed

+118
-42
lines changed

2 files changed

+118
-42
lines changed

src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ private static void buildCostFnRecursively(RewriterStatement costFn, Map<Rewrite
545545
// Returns the set of all active statements after the rewrite
546546
private static Set<RewriterStatement> buildRewrite(RewriterStatement newRoot, StringBuilder sb, RewriterAssertions assertions, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation) {
547547
Set<RewriterStatement> visited = new HashSet<>();
548-
recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT"));
548+
recursivelyBuildNewHop(sb, newRoot, assertions, vars, ctx, indentation, 1, visited, newRoot.getResultingDataType(ctx).equals("FLOAT"), new ArrayList<>());
549549

550550
return visited;
551551
}
@@ -561,13 +561,13 @@ private static void removeUnreferencedHops(RewriterStatement oldRoot, Set<Rewrit
561561
}, false);
562562
}
563563

564-
private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation, int varCtr, Set<RewriterStatement> visited, boolean enforceRootDataType) {
564+
private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cur, RewriterAssertions assertions, Map<RewriterStatement, String> vars, final RuleContext ctx, int indentation, int varCtr, Set<RewriterStatement> visited, boolean enforceRootDataType, List<String> createdOps) {
565565
visited.add(cur);
566566
if (vars.containsKey(cur))
567567
return varCtr;
568568

569569
for (RewriterStatement child : cur.getOperands())
570-
varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false);
570+
varCtr = recursivelyBuildNewHop(sb, child, assertions, vars, ctx, indentation, varCtr, visited, false, createdOps);
571571

572572
if (cur instanceof RewriterDataType) {
573573
if (cur.isLiteral()) {
@@ -610,6 +610,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
610610
sb.append("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n");
611611
}
612612
vars.put(cur, name);
613+
createdOps.add(name);
613614
}
614615

615616
return varCtr;
@@ -620,17 +621,31 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
620621
if (CodeGenUtils.opRequiresBinaryBroadcastingMatch(cur, ctx)) {
621622
// Then we need to validate that broadcasting still works after rearranging
622623
indent(indentation, sb);
623-
sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") )\n");
624+
sb.append("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs[0] + ", " + operandRefs[1] + ") ) {\n");
625+
for (String createdOp : createdOps) {
626+
// Properly remove the references to the newly constructed ops
627+
indent(indentation+1, sb);
628+
sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n");
629+
}
624630
indent(indentation+1, sb);
625631
sb.append("return hi;\n");
632+
indent(indentation, sb);
633+
sb.append("}\n");
626634
} else {
627635
List<Integer> matchingDims = CodeGenUtils.matchingDimRequirement(cur, ctx);
628636

629637
if (!matchingDims.isEmpty()) {
630638
// Then we need to validate that broadcasting still works after rearranging
631-
sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") )\n");
639+
sb.append("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims.stream().map(idx -> operandRefs[idx]).collect(Collectors.joining(", ")) + ") ) {\n");
640+
for (String createdOp : createdOps) {
641+
// Properly remove the references to the newly constructed ops
642+
indent(indentation+1, sb);
643+
sb.append("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n");
644+
}
632645
indent(indentation+1, sb);
633646
sb.append("return hi;\n");
647+
indent(indentation, sb);
648+
sb.append("}\n");
634649
}
635650
}
636651

@@ -640,6 +655,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
640655
sb.append(opClass + " " + name + " = " + constructor + ";\n");
641656

642657
vars.put(cur, name);
658+
createdOps.add(name);
643659
}
644660

645661
return varCtr;

0 commit comments

Comments
 (0)