@@ -545,7 +545,7 @@ private static void buildCostFnRecursively(RewriterStatement costFn, Map<Rewrite
545545 // Returns the set of all active statements after the rewrite
546546 private static Set <RewriterStatement > buildRewrite (RewriterStatement newRoot , StringBuilder sb , RewriterAssertions assertions , Map <RewriterStatement , String > vars , final RuleContext ctx , int indentation ) {
547547 Set <RewriterStatement > visited = new HashSet <>();
548- recursivelyBuildNewHop (sb , newRoot , assertions , vars , ctx , indentation , 1 , visited , newRoot .getResultingDataType (ctx ).equals ("FLOAT" ));
548+ recursivelyBuildNewHop (sb , newRoot , assertions , vars , ctx , indentation , 1 , visited , newRoot .getResultingDataType (ctx ).equals ("FLOAT" ), new ArrayList <>() );
549549
550550 return visited ;
551551 }
@@ -561,13 +561,13 @@ private static void removeUnreferencedHops(RewriterStatement oldRoot, Set<Rewrit
561561 }, false );
562562 }
563563
564- private static int recursivelyBuildNewHop (StringBuilder sb , RewriterStatement cur , RewriterAssertions assertions , Map <RewriterStatement , String > vars , final RuleContext ctx , int indentation , int varCtr , Set <RewriterStatement > visited , boolean enforceRootDataType ) {
564+ private static int recursivelyBuildNewHop (StringBuilder sb , RewriterStatement cur , RewriterAssertions assertions , Map <RewriterStatement , String > vars , final RuleContext ctx , int indentation , int varCtr , Set <RewriterStatement > visited , boolean enforceRootDataType , List < String > createdOps ) {
565565 visited .add (cur );
566566 if (vars .containsKey (cur ))
567567 return varCtr ;
568568
569569 for (RewriterStatement child : cur .getOperands ())
570- varCtr = recursivelyBuildNewHop (sb , child , assertions , vars , ctx , indentation , varCtr , visited , false );
570+ varCtr = recursivelyBuildNewHop (sb , child , assertions , vars , ctx , indentation , varCtr , visited , false , createdOps );
571571
572572 if (cur instanceof RewriterDataType ) {
573573 if (cur .isLiteral ()) {
@@ -610,6 +610,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
610610 sb .append ("LiteralOp " + name + " = new LiteralOp( " + literalStr + " );\n " );
611611 }
612612 vars .put (cur , name );
613+ createdOps .add (name );
613614 }
614615
615616 return varCtr ;
@@ -620,17 +621,31 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
620621 if (CodeGenUtils .opRequiresBinaryBroadcastingMatch (cur , ctx )) {
621622 // Then we need to validate that broadcasting still works after rearranging
622623 indent (indentation , sb );
623- sb .append ("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs [0 ] + ", " + operandRefs [1 ] + ") )\n " );
624+ sb .append ("if ( !RewriterRuntimeUtils.validateBinaryBroadcasting(" + operandRefs [0 ] + ", " + operandRefs [1 ] + ") ) {\n " );
625+ for (String createdOp : createdOps ) {
626+ // Properly remove the references to the newly constructed ops
627+ indent (indentation +1 , sb );
628+ sb .append ("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n " );
629+ }
624630 indent (indentation +1 , sb );
625631 sb .append ("return hi;\n " );
632+ indent (indentation , sb );
633+ sb .append ("}\n " );
626634 } else {
627635 List <Integer > matchingDims = CodeGenUtils .matchingDimRequirement (cur , ctx );
628636
629637 if (!matchingDims .isEmpty ()) {
630638 // Then we need to validate that broadcasting still works after rearranging
631- sb .append ("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims .stream ().map (idx -> operandRefs [idx ]).collect (Collectors .joining (", " )) + ") )\n " );
639+ sb .append ("if ( !RewriterRuntimeUtils.hasMatchingDims(" + matchingDims .stream ().map (idx -> operandRefs [idx ]).collect (Collectors .joining (", " )) + ") ) {\n " );
640+ for (String createdOp : createdOps ) {
641+ // Properly remove the references to the newly constructed ops
642+ indent (indentation +1 , sb );
643+ sb .append ("HopRewriteUtils.removeAllChildReferences(" + createdOp + ");\n " );
644+ }
632645 indent (indentation +1 , sb );
633646 sb .append ("return hi;\n " );
647+ indent (indentation , sb );
648+ sb .append ("}\n " );
634649 }
635650 }
636651
@@ -640,6 +655,7 @@ private static int recursivelyBuildNewHop(StringBuilder sb, RewriterStatement cu
640655 sb .append (opClass + " " + name + " = " + constructor + ";\n " );
641656
642657 vars .put (cur , name );
658+ createdOps .add (name );
643659 }
644660
645661 return varCtr ;
0 commit comments