5555#include "llvm/ADT/MapVector.h"
5656#include <cstddef>
5757#include <iterator>
58+ #include <mlir/IR/Value.h>
5859#include <numeric>
5960#define DEBUG_TYPE "enzymehloopt"
6061
@@ -25445,8 +25446,9 @@ struct DotGeneralToSyrk
2544525446 cast<ElementsAttr>(makeAttr(alphaType, 0))),
2544625447 enzymexla::LapackUploAttr::get(op.getContext(),
2544725448 enzymexla::LapackUplo::F),
25448- enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25449- rewriter.getUnitAttr());
25449+ enzymexla::LapackUploAttr::get(op.getContext(),
25450+ enzymexla::LapackUplo::F),
25451+ enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose));
2545025452 rewriter.replaceOp(op, syrkOp.getResult());
2545125453 return success();
2545225454 }
@@ -25460,31 +25462,26 @@ struct TransposeSyrkToSyrk
2546025462 LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
2546125463 PatternRewriter &rewriter) const {
2546225464 auto input = op.getA();
25463- if (cast<RankedTensorType>(input.getType()).getRank() != 2)
25465+ if (cast<RankedTensorType>(input.getType()).getRank() != 2) {
2546425466 return failure(); // support only rank 2 matrices for now
25467+ }
2546525468
2546625469 auto transposeOp = input.getDefiningOp<stablehlo::TransposeOp>();
25467- if (!transposeOp)
25470+ if (!transposeOp) {
2546825471 return failure();
25472+ }
2546925473
2547025474 auto perm = transposeOp.getPermutation();
25471- if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0)
25475+ if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
2547225476 return failure();
25473-
25474- enzymexla::LapackTranspose lapackTranspose;
25475- switch (op.getTranspose()) {
25476- case enzymexla::LapackTranspose::none:
25477- lapackTranspose = enzymexla::LapackTranspose::transpose;
25478- break;
25479- default:
25480- lapackTranspose = enzymexla::LapackTranspose::none;
2548125477 }
2548225478
2548325479 rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2548425480 op, op.getResult().getType(), transposeOp.getOperand(), op.getC(),
25485- op.getAlpha(), op.getBeta(), op.getUploAttr(),
25486- enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25487- op.getFillAttr());
25481+ op.getAlpha(), op.getBeta(), op.getUploAttr(), op.getOutputUploAttr(),
25482+ enzymexla::LapackTransposeAttr::get(
25483+ op.getContext(),
25484+ enzyme::transposeLapackTranspose(op.getTranspose(), false)));
2548825485 return success();
2548925486 }
2549025487};
@@ -25527,7 +25524,8 @@ struct FuseMulIntoSyrk
2552725524
2552825525 rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2552925526 op, syrkOp.getType(), syrkOp.getA(), syrkOp.getC(), newAlpha, newBeta,
25530- syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
25527+ syrkOp.getUploAttr(), syrkOp.getOutputUploAttr(),
25528+ syrkOp.getTransposeAttr());
2553125529 return success();
2553225530 }
2553325531};
@@ -25579,7 +25577,159 @@ struct FuseAddIntoSyrk
2557925577
2558025578 rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2558125579 op, syrkOp.getType(), syrkOp.getA(), newC, syrkOp.getAlpha(), newBeta,
25582- syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
25580+ syrkOp.getUploAttr(), syrkOp.getOutputUploAttr(),
25581+ syrkOp.getTransposeAttr());
25582+ return success();
25583+ }
25584+ };
25585+
25586+ struct SyrkSimplifyOutputUplo final
25587+ : public CheckedOpRewritePattern<enzymexla::SyrkOp,
25588+ SyrkSimplifyOutputUplo> {
25589+ using CheckedOpRewritePattern<
25590+ enzymexla::SyrkOp, SyrkSimplifyOutputUplo>::CheckedOpRewritePattern;
25591+
25592+ LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
25593+ PatternRewriter &rewriter) const {
25594+ // we can also try to align the uplos for other cases but that would be
25595+ // less common
25596+ if (op.getOutputUplo() != enzymexla::LapackUplo::F) {
25597+ return rewriter.notifyMatchFailure(op, "output_uplo is not F");
25598+ }
25599+
25600+ // Track all users of syrk. If these end at syrk ops (via supported
25601+ // elementwise operations) we can potentially set a common uplo.
25602+ SmallVector<enzymexla::SyrkOp> childSyrkOps;
25603+
25604+ // Worklist approach: track all children that need to be processed
25605+ SmallVector<Value> worklist;
25606+ llvm::SmallPtrSet<Value, 8> visited;
25607+
25608+ // Start with the result of the current syrk op
25609+ worklist.push_back(op.getResult());
25610+ visited.insert(op.getResult());
25611+
25612+ while (!worklist.empty()) {
25613+ Value current = worklist.pop_back_val();
25614+
25615+ // Check all users of the current value
25616+ for (Operation *user : current.getUsers()) {
25617+ // If user is a syrk op, add it to our list of child syrk ops
25618+ if (auto childSyrk = dyn_cast<enzymexla::SyrkOp>(user)) {
25619+ if (childSyrk.getC() != current) {
25620+ return failure();
25621+ }
25622+ childSyrkOps.push_back(childSyrk);
25623+ continue;
25624+ }
25625+
25626+ // Check if the operation is a supported elementwise operation
25627+ // that preserves symmetry (elementwise ops preserve the symmetric
25628+ // structure)
25629+ if (!stablehlo::hasTraitElementwise(user)) {
25630+ // Found a non-elementwise user that is not a syrk op
25631+ // This is not supported, so we fail
25632+ return rewriter.notifyMatchFailure(
25633+ op, "found non-elementwise, non-syrk user");
25634+ }
25635+
25636+ // For elementwise operations, add their results to the worklist
25637+ for (Value result : user->getResults()) {
25638+ if (!visited.contains(result)) {
25639+ visited.insert(result);
25640+ worklist.push_back(result);
25641+ }
25642+ }
25643+ }
25644+ }
25645+
25646+ // If no child syrk ops were found, nothing to optimize
25647+ if (childSyrkOps.empty()) {
25648+ return rewriter.notifyMatchFailure(op, "no child syrk ops found");
25649+ }
25650+
25651+ // Check the uplos for all child syrk ops
25652+ // Collect all the uplos we see
25653+ bool hasUpper = false, hasLower = false;
25654+ int countOutputUpper = 0, countOutputLower = 0;
25655+
25656+ for (enzymexla::SyrkOp childSyrk : childSyrkOps) {
25657+ enzymexla::LapackUplo childUplo = childSyrk.getUplo();
25658+ switch (childUplo) {
25659+ case enzymexla::LapackUplo::U:
25660+ hasUpper = true;
25661+ break;
25662+ case enzymexla::LapackUplo::L:
25663+ hasLower = true;
25664+ break;
25665+ case enzymexla::LapackUplo::F:
25666+ break;
25667+ }
25668+
25669+ enzymexla::LapackUplo childOutputUplo = childSyrk.getOutputUplo();
25670+ switch (childOutputUplo) {
25671+ case enzymexla::LapackUplo::U:
25672+ countOutputUpper++;
25673+ break;
25674+ case enzymexla::LapackUplo::L:
25675+ countOutputLower++;
25676+ break;
25677+ case enzymexla::LapackUplo::F:
25678+ break;
25679+ }
25680+ }
25681+
25682+ // Check for conflict: if we have both U and L, we cannot satisfy both
25683+ if (hasUpper && hasLower) {
25684+ return rewriter.notifyMatchFailure(
25685+ op, "conflicting uplos among child syrk ops (both U and L)");
25686+ }
25687+
25688+ // Determine the common uplo
25689+ enzymexla::LapackUplo newOutputUplo;
25690+ if (hasUpper) {
25691+ // At least one child requires U, set output to U
25692+ newOutputUplo = enzymexla::LapackUplo::U;
25693+ } else if (hasLower) {
25694+ // At least one child requires L, set output to L
25695+ newOutputUplo = enzymexla::LapackUplo::L;
25696+ } else {
25697+ // All children have uplo F
25698+ // Check the output_uplo of each child to find the uplo that minimizes
25699+ // copying. According to LowerEnzymeXLABlas::resolveUplo:
25700+ // - When uplo is F, needsCopy = (output_uplo == F)
25701+ // - If output_uplo is U or L, no copy is needed when input matches
25702+
25703+ // Choose the uplo that matches the most children's output_uplo
25704+ // This minimizes the number of copies needed
25705+ if (countOutputUpper > countOutputLower) {
25706+ newOutputUplo = enzymexla::LapackUplo::U;
25707+ } else if (countOutputLower > countOutputUpper) {
25708+ newOutputUplo = enzymexla::LapackUplo::L;
25709+ } else {
25710+ // Tied or all have output_uplo F
25711+ // Default to U as per standardizeUplo (see LowerEnzymeXLABlas)
25712+ newOutputUplo = enzymexla::LapackUplo::U;
25713+ }
25714+ }
25715+
25716+ // If the new output uplo is the same as what we already have, no change
25717+ if (newOutputUplo == op.getOutputUplo()) {
25718+ return rewriter.notifyMatchFailure(op, "output_uplo already optimal");
25719+ }
25720+
25721+ auto newUploAttr =
25722+ enzymexla::LapackUploAttr::get(rewriter.getContext(), newOutputUplo);
25723+
25724+ // Create a new syrk op with the updated output_uplo
25725+ rewriter.modifyOpInPlace(op, [&]() { op.setOutputUploAttr(newUploAttr); });
25726+
25727+ // Also update the child syrk ops that have uplo F to use the new uplo
25728+ // This ensures they don't need to copy
25729+ for (enzymexla::SyrkOp &childSyrk : childSyrkOps) {
25730+ rewriter.modifyOpInPlace(childSyrk,
25731+ [&]() { childSyrk.setUploAttr(newUploAttr); });
25732+ }
2558325733 return success();
2558425734 }
2558525735};
@@ -27894,6 +28044,7 @@ struct EnzymeHLOOptPass
2789428044 DotGeneralRemoveBatchDimensions,
2789528045 DUSDynamicSliceSimplify,
2789628046 WhileDUSDSSimplify,
28047+ SyrkSimplifyOutputUplo,
2789728048 WhileDUSDUSSimplify,
2789828049 WhileDUS,
2789928050 DeleteDimsReduce,
0 commit comments