Skip to content

Commit 0980456

Browse files
committed
fix
1 parent 828916b commit 0980456

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLABLAS.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,10 @@ struct SyrkOpLowering : public OpRewritePattern<enzymexla::SyrkOp> {
485485
rewriter, op.getLoc(), cast<RankedTensorType>(op.getC().getType()),
486486
op.getA(), op.getA(), dotDims, nullptr, nullptr);
487487

488-
auto res = stablehlo::AddOpCreate(
489-
rewriter, op->getLoc(),
490-
stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getAlpha(), AAT),
491-
stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getBeta(), C));
488+
auto aop = stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getAlpha(), AAT);
489+
auto bop = stablehlo::MulOpCreate(rewriter, op->getLoc(), op.getBeta(), C);
490+
491+
auto res = stablehlo::AddOpCreate(rewriter, op->getLoc(), aop, bop);
492492
rewriter.replaceOp(op, res);
493493
return success();
494494
}

0 commit comments

Comments
 (0)