Skip to content

Commit 2680f71

Browse files
authored
feat: syrk simplifications and fixes (#1781)
* feat: syrk simplifications and fixes * feat: efficient copy
1 parent 426a717 commit 2680f71

File tree

13 files changed

+906
-195
lines changed

13 files changed

+906
-195
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,19 @@ def SyrkOp: EnzymeXLA_Op<"blas.syrk", [Pure, SameOperandsAndResultElementType]>
431431
let summary = "Multiplication involving a symmetric matrix";
432432

433433
let description = [{
434-
C := alpha*A*A^T + beta*C, or C := alpha*A^T*A + beta*C, where alpha and beta are
435-
scalars. C must be a n x n symmetric matrix.
436-
437-
If `fill` is present, then both the upper and lower triangles of the matrix are filled.
438-
Otherwise the values in the non-uplo part of the matrix are undefined.
434+
C_out := alpha*A*A^T + beta*C, or C_out := alpha*A^T*A + beta*C, where alpha and beta
435+
are scalars. C must be a n x n symmetric matrix.
436+
437+
`output_uplo` determines which part of `C_out` is populated. Accessing the values in
438+
the non-`output_uplo` part of the matrix is undefined behavior.
439+
440+
LAPACK/BLAS routines typically require a single `uplo` attribute and it is implicitly
441+
assumed that the output `uplo` corresponds to the input `uplo`. This means the burden
442+
lies on the user to manually copy data if they need to access the other half of the
443+
matrix. By specifying the `output_uplo` we can perform transformations that analyze the
444+
entire dataflow, and avoid computing/copying half of the tensor all together. Generally,
445+
it is recommended to set this attribute to `enzymexla::LapackUplo::F`, and our passes
446+
will automatically refine this to minimize data copies.
439447
}];
440448

441449
let arguments = (ins
@@ -444,8 +452,8 @@ def SyrkOp: EnzymeXLA_Op<"blas.syrk", [Pure, SameOperandsAndResultElementType]>
444452
TensorFloat:$alpha,
445453
TensorFloat:$beta,
446454
EnzymeXLA_LapackUploAttr:$uplo,
447-
DefaultValuedAttr<EnzymeXLA_LapackTransposeAttr, "::mlir::enzymexla::LapackTranspose::none">:$transpose,
448-
OptionalAttr<UnitAttr>:$fill
455+
EnzymeXLA_LapackUploAttr:$output_uplo,
456+
DefaultValuedAttr<EnzymeXLA_LapackTransposeAttr, "::mlir::enzymexla::LapackTranspose::none">:$transpose
449457
);
450458

451459
let results = (outs
@@ -455,6 +463,8 @@ def SyrkOp: EnzymeXLA_Op<"blas.syrk", [Pure, SameOperandsAndResultElementType]>
455463
let assemblyFormat = [{
456464
$A `,` $C `,` $alpha `,` $beta attr-dict `:` functional-type(operands, results)
457465
}];
466+
467+
let hasVerifier = 1;
458468
}
459469

460470
def TrmmOp: EnzymeXLA_Op<"blas.trmm", [Pure, SameOperandsAndResultElementType]> {

src/enzyme_ad/jax/Dialect/Ops.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,20 @@ LogicalResult enzymexla::MemcpyOp::verify() {
11041104
return success();
11051105
}
11061106

1107+
LogicalResult enzymexla::SyrkOp::verify() {
1108+
auto CType = cast<RankedTensorType>(getC().getType());
1109+
bool isComplex = false;
1110+
if (auto complex_type = dyn_cast<ComplexType>(CType.getElementType())) {
1111+
isComplex = true;
1112+
}
1113+
1114+
if (isComplex && getTranspose() == enzymexla::LapackTranspose::adjoint) {
1115+
return emitOpError("Complex matrix not supported for complex transpose");
1116+
}
1117+
1118+
return success();
1119+
}
1120+
11071121
namespace {
11081122

11091123
/// Erases a common case of copy ops where a destination value is used only by

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 169 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
#include "llvm/ADT/MapVector.h"
5656
#include <cstddef>
5757
#include <iterator>
58+
#include <mlir/IR/Value.h>
5859
#include <numeric>
5960
#define DEBUG_TYPE "enzymehloopt"
6061

@@ -25445,8 +25446,9 @@ struct DotGeneralToSyrk
2544525446
cast<ElementsAttr>(makeAttr(alphaType, 0))),
2544625447
enzymexla::LapackUploAttr::get(op.getContext(),
2544725448
enzymexla::LapackUplo::F),
25448-
enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25449-
rewriter.getUnitAttr());
25449+
enzymexla::LapackUploAttr::get(op.getContext(),
25450+
enzymexla::LapackUplo::F),
25451+
enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose));
2545025452
rewriter.replaceOp(op, syrkOp.getResult());
2545125453
return success();
2545225454
}
@@ -25460,31 +25462,26 @@ struct TransposeSyrkToSyrk
2546025462
LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
2546125463
PatternRewriter &rewriter) const {
2546225464
auto input = op.getA();
25463-
if (cast<RankedTensorType>(input.getType()).getRank() != 2)
25465+
if (cast<RankedTensorType>(input.getType()).getRank() != 2) {
2546425466
return failure(); // support only rank 2 matrices for now
25467+
}
2546525468

2546625469
auto transposeOp = input.getDefiningOp<stablehlo::TransposeOp>();
25467-
if (!transposeOp)
25470+
if (!transposeOp) {
2546825471
return failure();
25472+
}
2546925473

2547025474
auto perm = transposeOp.getPermutation();
25471-
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0)
25475+
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
2547225476
return failure();
25473-
25474-
enzymexla::LapackTranspose lapackTranspose;
25475-
switch (op.getTranspose()) {
25476-
case enzymexla::LapackTranspose::none:
25477-
lapackTranspose = enzymexla::LapackTranspose::transpose;
25478-
break;
25479-
default:
25480-
lapackTranspose = enzymexla::LapackTranspose::none;
2548125477
}
2548225478

2548325479
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2548425480
op, op.getResult().getType(), transposeOp.getOperand(), op.getC(),
25485-
op.getAlpha(), op.getBeta(), op.getUploAttr(),
25486-
enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25487-
op.getFillAttr());
25481+
op.getAlpha(), op.getBeta(), op.getUploAttr(), op.getOutputUploAttr(),
25482+
enzymexla::LapackTransposeAttr::get(
25483+
op.getContext(),
25484+
enzyme::transposeLapackTranspose(op.getTranspose(), false)));
2548825485
return success();
2548925486
}
2549025487
};
@@ -25527,7 +25524,8 @@ struct FuseMulIntoSyrk
2552725524

2552825525
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2552925526
op, syrkOp.getType(), syrkOp.getA(), syrkOp.getC(), newAlpha, newBeta,
25530-
syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
25527+
syrkOp.getUploAttr(), syrkOp.getOutputUploAttr(),
25528+
syrkOp.getTransposeAttr());
2553125529
return success();
2553225530
}
2553325531
};
@@ -25579,7 +25577,159 @@ struct FuseAddIntoSyrk
2557925577

2558025578
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
2558125579
op, syrkOp.getType(), syrkOp.getA(), newC, syrkOp.getAlpha(), newBeta,
25582-
syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
25580+
syrkOp.getUploAttr(), syrkOp.getOutputUploAttr(),
25581+
syrkOp.getTransposeAttr());
25582+
return success();
25583+
}
25584+
};
25585+
25586+
struct SyrkSimplifyOutputUplo final
25587+
: public CheckedOpRewritePattern<enzymexla::SyrkOp,
25588+
SyrkSimplifyOutputUplo> {
25589+
using CheckedOpRewritePattern<
25590+
enzymexla::SyrkOp, SyrkSimplifyOutputUplo>::CheckedOpRewritePattern;
25591+
25592+
LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
25593+
PatternRewriter &rewriter) const {
25594+
// we can also try to align the uplos for other cases but that would be
25595+
// less common
25596+
if (op.getOutputUplo() != enzymexla::LapackUplo::F) {
25597+
return rewriter.notifyMatchFailure(op, "output_uplo is not F");
25598+
}
25599+
25600+
// Track all users of syrk. If these end at syrk ops (via supported
25601+
// elementwise operations) we can potentially set a common uplo.
25602+
SmallVector<enzymexla::SyrkOp> childSyrkOps;
25603+
25604+
// Worklist approach: track all children that need to be processed
25605+
SmallVector<Value> worklist;
25606+
llvm::SmallPtrSet<Value, 8> visited;
25607+
25608+
// Start with the result of the current syrk op
25609+
worklist.push_back(op.getResult());
25610+
visited.insert(op.getResult());
25611+
25612+
while (!worklist.empty()) {
25613+
Value current = worklist.pop_back_val();
25614+
25615+
// Check all users of the current value
25616+
for (Operation *user : current.getUsers()) {
25617+
// If user is a syrk op, add it to our list of child syrk ops
25618+
if (auto childSyrk = dyn_cast<enzymexla::SyrkOp>(user)) {
25619+
if (childSyrk.getC() != current) {
25620+
return failure();
25621+
}
25622+
childSyrkOps.push_back(childSyrk);
25623+
continue;
25624+
}
25625+
25626+
// Check if the operation is a supported elementwise operation
25627+
// that preserves symmetry (elementwise ops preserve the symmetric
25628+
// structure)
25629+
if (!stablehlo::hasTraitElementwise(user)) {
25630+
// Found a non-elementwise user that is not a syrk op
25631+
// This is not supported, so we fail
25632+
return rewriter.notifyMatchFailure(
25633+
op, "found non-elementwise, non-syrk user");
25634+
}
25635+
25636+
// For elementwise operations, add their results to the worklist
25637+
for (Value result : user->getResults()) {
25638+
if (!visited.contains(result)) {
25639+
visited.insert(result);
25640+
worklist.push_back(result);
25641+
}
25642+
}
25643+
}
25644+
}
25645+
25646+
// If no child syrk ops were found, nothing to optimize
25647+
if (childSyrkOps.empty()) {
25648+
return rewriter.notifyMatchFailure(op, "no child syrk ops found");
25649+
}
25650+
25651+
// Check the uplos for all child syrk ops
25652+
// Collect all the uplos we see
25653+
bool hasUpper = false, hasLower = false;
25654+
int countOutputUpper = 0, countOutputLower = 0;
25655+
25656+
for (enzymexla::SyrkOp childSyrk : childSyrkOps) {
25657+
enzymexla::LapackUplo childUplo = childSyrk.getUplo();
25658+
switch (childUplo) {
25659+
case enzymexla::LapackUplo::U:
25660+
hasUpper = true;
25661+
break;
25662+
case enzymexla::LapackUplo::L:
25663+
hasLower = true;
25664+
break;
25665+
case enzymexla::LapackUplo::F:
25666+
break;
25667+
}
25668+
25669+
enzymexla::LapackUplo childOutputUplo = childSyrk.getOutputUplo();
25670+
switch (childOutputUplo) {
25671+
case enzymexla::LapackUplo::U:
25672+
countOutputUpper++;
25673+
break;
25674+
case enzymexla::LapackUplo::L:
25675+
countOutputLower++;
25676+
break;
25677+
case enzymexla::LapackUplo::F:
25678+
break;
25679+
}
25680+
}
25681+
25682+
// Check for conflict: if we have both U and L, we cannot satisfy both
25683+
if (hasUpper && hasLower) {
25684+
return rewriter.notifyMatchFailure(
25685+
op, "conflicting uplos among child syrk ops (both U and L)");
25686+
}
25687+
25688+
// Determine the common uplo
25689+
enzymexla::LapackUplo newOutputUplo;
25690+
if (hasUpper) {
25691+
// At least one child requires U, set output to U
25692+
newOutputUplo = enzymexla::LapackUplo::U;
25693+
} else if (hasLower) {
25694+
// At least one child requires L, set output to L
25695+
newOutputUplo = enzymexla::LapackUplo::L;
25696+
} else {
25697+
// All children have uplo F
25698+
// Check the output_uplo of each child to find the uplo that minimizes
25699+
// copying. According to LowerEnzymeXLABlas::resolveUplo:
25700+
// - When uplo is F, needsCopy = (output_uplo == F)
25701+
// - If output_uplo is U or L, no copy is needed when input matches
25702+
25703+
// Choose the uplo that matches the most children's output_uplo
25704+
// This minimizes the number of copies needed
25705+
if (countOutputUpper > countOutputLower) {
25706+
newOutputUplo = enzymexla::LapackUplo::U;
25707+
} else if (countOutputLower > countOutputUpper) {
25708+
newOutputUplo = enzymexla::LapackUplo::L;
25709+
} else {
25710+
// Tied or all have output_uplo F
25711+
// Default to U as per standardizeUplo (see LowerEnzymeXLABlas)
25712+
newOutputUplo = enzymexla::LapackUplo::U;
25713+
}
25714+
}
25715+
25716+
// If the new output uplo is the same as what we already have, no change
25717+
if (newOutputUplo == op.getOutputUplo()) {
25718+
return rewriter.notifyMatchFailure(op, "output_uplo already optimal");
25719+
}
25720+
25721+
auto newUploAttr =
25722+
enzymexla::LapackUploAttr::get(rewriter.getContext(), newOutputUplo);
25723+
25724+
// Create a new syrk op with the updated output_uplo
25725+
rewriter.modifyOpInPlace(op, [&]() { op.setOutputUploAttr(newUploAttr); });
25726+
25727+
// Also update the child syrk ops that have uplo F to use the new uplo
25728+
// This ensures they don't need to copy
25729+
for (enzymexla::SyrkOp &childSyrk : childSyrkOps) {
25730+
rewriter.modifyOpInPlace(childSyrk,
25731+
[&]() { childSyrk.setUploAttr(newUploAttr); });
25732+
}
2558325733
return success();
2558425734
}
2558525735
};
@@ -27894,6 +28044,7 @@ struct EnzymeHLOOptPass
2789428044
DotGeneralRemoveBatchDimensions,
2789528045
DUSDynamicSliceSimplify,
2789628046
WhileDUSDSSimplify,
28047+
SyrkSimplifyOutputUplo,
2789728048
WhileDUSDUSSimplify,
2789828049
WhileDUS,
2789928050
DeleteDimsReduce,

src/enzyme_ad/jax/Passes/LinalgUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define ENZYMEXLA_LINALGUTILS_H
33

44
#include "mlir/IR/Attributes.h"
5-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
5+
#include "mlir/IR/PatternMatch.h"
66
#include "llvm/ADT/SmallVector.h"
77

88
llvm::SmallVector<int64_t> columnMajorMatrixLayout(int64_t ndim);

0 commit comments

Comments
 (0)