
Commit 546c4ed

snonk and avik-pal authored
feat: symmetric tensor detection + transpose(symmetric) simplify (#1549)
* added symm op
* fmt
* wip
* Update src/enzyme_ad/jax/Utils.cpp

Co-authored-by: Avik Pal <[email protected]>

* chore: run fmt
* feat: transpose symmetric simplify
* feat: generalize is_commutative check
* fix: more checks + update tests
* feat: generalizes the constant check

---------

Co-authored-by: Avik Pal <[email protected]>
1 parent 044850d commit 546c4ed
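
In short: the commit teaches the optimizer to detect values that are guaranteed to be symmetric matrices and to fold transposes of them away. The algebra behind the pattern (a sketch of the reasoning, not text from the commit): a square matrix $S$ is symmetric when $S = S^\top$, so $\operatorname{transpose}(S) \to S$, and for any commutative elementwise op

$$\operatorname{op}(A, A^\top)^\top = \operatorname{op}(A^\top, A) = \operatorname{op}(A, A^\top),$$

where the first equality uses elementwiseness and the second commutativity, so op(A, A^T) always yields a symmetric result the analysis can seed from.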

File tree

Showing 12 changed files with 273 additions and 42 deletions.


src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 32 additions & 0 deletions

@@ -6907,6 +6907,36 @@ struct SubSimplify
   }
 };
 
+struct TransposeSymmetricSimplify
+    : public CheckedOpRewritePattern<stablehlo::TransposeOp,
+                                     TransposeSymmetricSimplify> {
+  using CheckedOpRewritePattern<
+      stablehlo::TransposeOp,
+      TransposeSymmetricSimplify>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
+                                    PatternRewriter &rewriter) const {
+    auto defOp = op.getOperand().getDefiningOp();
+    if (!defOp)
+      return rewriter.notifyMatchFailure(op, "no defining op");
+
+    auto perm = op.getPermutation();
+    if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0)
+      return failure();
+
+    auto resTy = cast<RankedTensorType>(op.getResult().getType());
+    if (resTy.getRank() != 2 || resTy.getDimSize(0) != resTy.getDimSize(1))
+      return failure(); // quick check and exit
+
+    if (canApplySymmetricPattern(
+            defOp, rewriter)) { // transpose(symmetric) -> symmetric
+      rewriter.replaceOp(op, op.getOperand());
+      return success();
+    }
+    return failure();
+  }
+};
+
 struct NoNanSelfSubSimplify
     : public NoNanCheckedOpRewritePattern<stablehlo::SubtractOp,
                                           NoNanSelfSubSimplify> {

@@ -25994,6 +26024,8 @@ struct EnzymeHLOOptPass
         NoNanAddSubSimplify, NoNanMulSimplify, NoNanDivSimplify>(
         (no_nan || all_finite), context);
 
+    patterns.add<TransposeSymmetricSimplify>(context);
+
     // clang-format off
     patterns.add<
         WhileRepeatedInductionReduction,
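
For illustration, a hand-written example of the IR shape TransposeSymmetricSimplify folds (a sketch, not one of the PR's lit tests): the outer transpose has permutation [1, 0] over a square 2-D result, and its operand is proven symmetric by the analysis in Utils.cpp below, so the transpose is replaced by its operand.

func.func @fold(%a: tensor<3x3xf64>) -> tensor<3x3xf64> {
  %at = stablehlo.transpose %a, dims = [1, 0] : (tensor<3x3xf64>) -> tensor<3x3xf64>
  // add is commutative and elementwise, so %s = A + A^T is guaranteed symmetric
  %s = stablehlo.add %a, %at : tensor<3x3xf64>
  // transpose(symmetric) -> symmetric: the pattern replaces all uses of %t with %s
  %t = stablehlo.transpose %s, dims = [1, 0] : (tensor<3x3xf64>) -> tensor<3x3xf64>
  return %t : tensor<3x3xf64>
}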

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions

@@ -599,6 +599,11 @@ def ApplyNoNanZeroBasePowSimplify : EnzymeHLOParameterizedPatternOp<
   }];
 }
 
+def ApplyTransposeSymmetricSimplify : EnzymeHLOPatternOp<
+    "transpose_symmetric_simplify"> {
+  let patterns = ["TransposeSymmetricSimplify"];
+}
+
 def ApplyTransposeElementwisePatterns : EnzymeHLOParameterizedPatternOp<
     "transpose_elementwise"> {
   let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
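
This registers the pattern under the name transpose_symmetric_simplify for transform-dialect driven runs. A hypothetical lit-test RUN line, assuming the pipeline flags follow the repo's existing transform-driven tests (only the pattern name comes from this diff):

// RUN: enzymexlamlir-opt %s --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" \
// RUN:   --transform-interpreter --enzyme-hlo-remove-transform | FileCheck %s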

src/enzyme_ad/jax/Utils.cpp

Lines changed: 138 additions & 6 deletions

@@ -584,6 +584,135 @@ bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type outTy, Type inTy,
   return allowOnFloatingPointMath || guaranteedNoNanResult(op, rewriter);
 }
 
+bool canApplySymmetricPattern(mlir::Operation *op, PatternRewriter &rewriter) {
+  return guaranteedSymmetricResult(op, rewriter);
+}
+
+SymmetricResultAnalysis initSymmetricResultAnalysis() {
+  return SymmetricResultAnalysis();
+}
+
+bool SymmetricResultAnalysis::constantIntCheck(DenseElementsAttr attr) {
+  return false; // TODO
+}
+
+bool SymmetricResultAnalysis::constantFloatCheck(DenseElementsAttr attr) {
+  return false; // TODO
+}
+
+SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
+    Operation *op, SmallVectorImpl<Operation *> &localtodo,
+    PatternRewriter &rewriter) {
+  assert(op);
+
+  auto outTy = cast<RankedTensorType>(op->getResult(0).getType());
+  if (outTy.getRank() != 2)
+    return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
+  if (outTy.getDimSize(0) != outTy.getDimSize(1))
+    return State::NOTGUARANTEED; // quick check and exit
+
+  SplatElementsAttr splatAttr;
+  if (matchPattern(op, m_Constant(&splatAttr))) {
+    return State::GUARANTEED;
+  }
+
+  DenseElementsAttr denseAttr;
+  if (matchPattern(op, m_Constant(&denseAttr))) {
+    if (guaranteedConstantOp(op, denseAttr, rewriter)) {
+      return State::GUARANTEED;
+    } else {
+      return State::NOTGUARANTEED;
+    }
+  }
+
+  // check that transpose dimensions are [1,0]
+  auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool {
+    auto perm = tOp.getPermutation();
+    return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
+  };
+
+  // TODO: check for dot_general as well
+
+  // a commutative elementwise operation on A and A^T is always symmetric:
+  // op(A, A^T)^T == op(A^T, A) == op(A, A^T)
+  if (stablehlo::hasTraitElementwise(op) &&
+      (op->hasTrait<OpTrait::IsCommutative>() ||
+       op->hasTrait<hlo::OpTrait::IsCommutative>())) {
+    auto lhs = op->getOperand(0);
+    auto rhs = op->getOperand(1);
+
+    // op(A, A^T)
+    if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
+      if (isTrueTranspose(rhsT)) {
+        if (lhs == rhsT.getOperand()) {
+          return State::GUARANTEED;
+        }
+      }
+    }
+
+    // op(A^T, A)
+    if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
+      if (isTrueTranspose(lhsT)) {
+        if (rhs == lhsT.getOperand()) {
+          return State::GUARANTEED;
+        }
+      }
+    }
+  }
+
+  bool recursiveCheck = false;
+
+  // elementwise ops
+  if (stablehlo::hasTraitElementwise(op)) {
+    recursiveCheck = true;
+  }
+
+  /**
+   * TODO
+   * - check if it is (x * 0) -> symmetric
+   */
+
+  if (recursiveCheck) {
+    bool allOperandsGuaranteed = true;
+    for (auto operand : op->getOperands()) {
+      {
+        auto found = valueCache.find(operand);
+        if (found != valueCache.end()) {
+          if (found->second) {
+            continue;
+          } else {
+            return State::NOTGUARANTEED;
+          }
+        }
+      }
+      auto dop = operand.getDefiningOp();
+      if (!dop)
+        return State::NOTGUARANTEED;
+
+      {
+        auto found = opCache.find(dop);
+        if (found != opCache.end()) {
+          if (found->second) {
+            continue;
+          } else {
+            return State::NOTGUARANTEED;
+          }
+        }
+      }
+
+      localtodo.push_back(dop);
+      allOperandsGuaranteed = false;
+    }
+
+    if (allOperandsGuaranteed)
+      return State::GUARANTEED;
+    else
+      return State::PENDING;
+  } else {
+    return State::NOTGUARANTEED;
+  }
+}
+
 NoNanResultAnalysis initNoNanResultAnalysis() {
   auto finiteAnalysis = std::make_shared<FiniteResultAnalysis>();
   auto noNanAnalysis = std::make_shared<NoNanResultAnalysis>();

@@ -617,8 +746,9 @@ NoNanResultAnalysis::localGuaranteed(Operation *op,
       return State::NOTGUARANTEED;
   }
 
-  if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
-    if (guaranteed(constantOp, rewriter)) {
+  DenseElementsAttr denseAttr;
+  if (matchPattern(op, m_Constant(&denseAttr))) {
+    if (guaranteedConstantOp(op, denseAttr, rewriter)) {
       return State::GUARANTEED;
     } else {
       return State::NOTGUARANTEED;

@@ -759,8 +889,9 @@ FiniteResultAnalysis::localGuaranteed(Operation *op,
       return State::NOTGUARANTEED;
   }
 
-  if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
-    if (guaranteed(constantOp, rewriter)) {
+  DenseElementsAttr denseAttr;
+  if (matchPattern(op, m_Constant(&denseAttr))) {
+    if (guaranteedConstantOp(op, denseAttr, rewriter)) {
      return State::GUARANTEED;
     } else {
       return State::NOTGUARANTEED;

@@ -871,8 +1002,9 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
      return State::NOTGUARANTEED;
   }
 
-  if (auto constantOp = dyn_cast<stablehlo::ConstantOp>(op)) {
-    if (guaranteed(constantOp, rewriter)) {
+  DenseElementsAttr denseAttr;
+  if (matchPattern(op, m_Constant(&denseAttr))) {
+    if (guaranteedConstantOp(op, denseAttr, rewriter)) {
       return State::GUARANTEED;
     } else {
       return State::NOTGUARANTEED;
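
Reading localGuaranteed top to bottom (a worked sketch, not from the PR): non-2-D or non-square results bail out immediately; splat constants are guaranteed; other constants defer to constantFloatCheck/constantIntCheck, which are still TODO and conservatively return false; a commutative elementwise op on A and A^T is guaranteed outright; and any other elementwise op returns PENDING until every operand is resolved through the caches and the localtodo worklist. That recursion is what lets symmetry propagate through unary elementwise ops:

// %s is locally guaranteed (commutative elementwise op on A and A^T);
// %c is elementwise, so it is PENDING until its operand %s resolves;
// the outer transpose %t can then be folded to %c.
%at = stablehlo.transpose %a, dims = [1, 0] : (tensor<3x3xf64>) -> tensor<3x3xf64>
%s = stablehlo.add %a, %at : tensor<3x3xf64>
%c = stablehlo.cosine %s : tensor<3x3xf64>
%t = stablehlo.transpose %c, dims = [1, 0] : (tensor<3x3xf64>) -> tensor<3x3xf64>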

src/enzyme_ad/jax/Utils.h

Lines changed: 46 additions & 25 deletions

@@ -313,6 +313,8 @@ bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type outTy, Type inTy);
 bool canApplyNoNanPattern(bool allowOnFloatingPointMath, Type outTy, Type inTy,
                           mlir::Operation *op, PatternRewriter &rewriter);
 
+bool canApplySymmetricPattern(mlir::Operation *op, PatternRewriter &rewriter);
+
 template <typename Child> class GuaranteedResultAnalysisBase {
 protected:
   llvm::DenseMap<mlir::Value, bool> valueCache;

@@ -492,51 +494,47 @@ template <typename Child> class GuaranteedResultAnalysisBase {
       return false;
   }
 
-  bool guaranteed(stablehlo::ConstantOp constOp, PatternRewriter &rewriter) {
-    if (!constOp)
+  bool guaranteedConstantOp(Operation *op, DenseElementsAttr denseAttr,
+                            PatternRewriter &rewriter) {
+    if (!op)
       return false;
 
     auto attrName = ((Child *)this)->getAttrName();
-    if (auto boolAttr = constOp->getAttrOfType<mlir::BoolAttr>(attrName)) {
+    if (auto boolAttr = op->getAttrOfType<mlir::BoolAttr>(attrName)) {
       if (boolAttr.getValue())
         return true;
       else
        return false;
     }
 
-    auto it = opCache.find(constOp);
+    auto it = opCache.find(op);
     if (it != opCache.end())
       return it->second;
 
-    Attribute attr = constOp.getValue();
-
     bool guaranteedResult = false;
-    if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr)) {
-      if (denseAttr.getType().getShape().size() && denseAttr.isSplat()) {
-        denseAttr = denseAttr.resizeSplat(
-            RankedTensorType::get({}, denseAttr.getType().getElementType()));
-      }
+    if (denseAttr.getType().getShape().size() && denseAttr.isSplat()) {
+      denseAttr = denseAttr.resizeSplat(
+          RankedTensorType::get({}, denseAttr.getType().getElementType()));
+    }
 
-      // For floating point values
-      if (isa<FloatType>(denseAttr.getElementType())) {
-        if (((Child *)this)->constantFloatCheck(denseAttr)) {
-          guaranteedResult = true;
-        }
+    // For floating point values
+    if (isa<FloatType>(denseAttr.getElementType())) {
+      if (((Child *)this)->constantFloatCheck(denseAttr)) {
+        guaranteedResult = true;
      }
+    }
 
-      // For integer values
-      if (isa<IntegerType>(denseAttr.getElementType())) {
-        if (((Child *)this)->constantIntCheck(denseAttr)) {
-          guaranteedResult = true;
-        }
+    // For integer values
+    if (isa<IntegerType>(denseAttr.getElementType())) {
+      if (((Child *)this)->constantIntCheck(denseAttr)) {
+        guaranteedResult = true;
      }
    }
 
-    rewriter.modifyOpInPlace(constOp, [&]() {
-      constOp->setAttr(attrName,
-                       BoolAttr::get(constOp.getContext(), guaranteedResult));
+    rewriter.modifyOpInPlace(op, [&]() {
+      op->setAttr(attrName, BoolAttr::get(op->getContext(), guaranteedResult));
     });
-    opCache[constOp] = guaranteedResult;
+    opCache[op] = guaranteedResult;
     return guaranteedResult;
   }

@@ -563,6 +561,19 @@ template <typename Child> class GuaranteedResultAnalysisBase {
 
 class FiniteResultAnalysis;
 class NoNanResultAnalysis;
+class SymmetricResultAnalysis;
+
+class SymmetricResultAnalysis
+    : public GuaranteedResultAnalysisBase<SymmetricResultAnalysis> {
+public:
+  State localGuaranteed(Operation *op, SmallVectorImpl<Operation *> &localtodo,
+                        PatternRewriter &rewriter);
+
+  bool constantFloatCheck(DenseElementsAttr attr);
+  bool constantIntCheck(DenseElementsAttr attr);
+
+  StringRef getAttrName() const { return "enzymexla.guaranteed_symmetric"; }
+};
 
 class NoNanResultAnalysis
     : public GuaranteedResultAnalysisBase<NoNanResultAnalysis> {

@@ -604,6 +615,7 @@ class FiniteResultAnalysis
 
 NoNanResultAnalysis initNoNanResultAnalysis();
 FiniteResultAnalysis initFiniteResultAnalysis();
+SymmetricResultAnalysis initSymmetricResultAnalysis();
 
 inline bool guaranteedNoNanResult(mlir::Value value,
                                   PatternRewriter &rewriter) {

@@ -621,6 +633,15 @@ inline bool guaranteedFiniteResult(Operation *op, PatternRewriter &rewriter) {
   return initFiniteResultAnalysis().guaranteed(op, rewriter);
 }
 
+inline bool guaranteedSymmetricResult(mlir::Value value,
+                                      PatternRewriter &rewriter) {
+  return initSymmetricResultAnalysis().guaranteed(value, rewriter);
+}
+inline bool guaranteedSymmetricResult(Operation *op,
+                                      PatternRewriter &rewriter) {
+  return initSymmetricResultAnalysis().guaranteed(op, rewriter);
+}
+
 class NonNegativeResultAnalysis
     : public GuaranteedResultAnalysisBase<NonNegativeResultAnalysis> {
 public:
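
Note how the analysis memoizes its verdict on the IR itself: the BoolAttr named by getAttrName (here enzymexla.guaranteed_symmetric) is stamped onto visited ops, so repeat queries short-circuit on the attribute instead of re-running the checks. That is exactly the annotation the updated lit tests below now expect, e.g. from binarytranspose.mlir:

// once the analysis has run and failed, the producer stays tagged:
%0 = stablehlo.multiply %arg0, %arg1 {enzymexla.guaranteed_symmetric = false} : tensor<3x3xf64>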

test/lit_tests/batchtests/reduce_autodiff.mlir

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ func.func @main(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>, %arg2: tensor<2xf64>
 
 // CHECK: func.func @main(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>, %arg2: tensor<2xf64>) -> (tensor<1xf64>, tensor<1xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) {
 // CHECK-NEXT: %0 = stablehlo.concatenate %arg1, %arg2, dim = 0 : (tensor<2xf64>, tensor<2xf64>) -> tensor<4xf64>
-// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<4xf64>) -> tensor<2x2xf64>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 {enzymexla.guaranteed_symmetric = false} : (tensor<4xf64>) -> tensor<2x2xf64>
 // CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
 // CHECK-NEXT: %3:3 = call @fwddiffe2fwd_autodiff(%arg0, %2) : (tensor<2xf64>, tensor<2x2xf64>) -> (tensor<2xf64>, tensor<2xf64>, tensor<2x2xf64>)
 // CHECK-NEXT: %4 = stablehlo.slice %3#0 [0:1] : (tensor<2xf64>) -> tensor<1xf64>

test/lit_tests/binarytranspose.mlir

Lines changed: 2 additions & 2 deletions

@@ -153,8 +153,8 @@ func.func @t14(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64
 }
 
 // CHECK: func.func @t14(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> {
-// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg1 : tensor<3x3xf64>
-// CHECK-NEXT: %1 = stablehlo.cosine %0 : tensor<3x3xf64>
+// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg1 {enzymexla.guaranteed_symmetric = false} : tensor<3x3xf64>
+// CHECK-NEXT: %1 = stablehlo.cosine %0 {enzymexla.guaranteed_symmetric = false} : tensor<3x3xf64>
 // CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<3x3xf64>) -> tensor<3x3xf64>
 // CHECK-NEXT: return %2 : tensor<3x3xf64>
 // CHECK-NEXT: }

test/lit_tests/diffrules/stablehlo/while4.mlir

Lines changed: 2 additions & 2 deletions

@@ -214,11 +214,11 @@ module {
 // CHECK-NEXT: %28 = stablehlo.multiply %25#5, %27 : tensor<3x2xf32>
 // CHECK-NEXT: %29 = stablehlo.reduce(%28 init: %cst_7) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
 // CHECK-NEXT: %30 = stablehlo.add %25#3, %29 : tensor<3xf32>
-// CHECK-NEXT: %31 = stablehlo.dot_general %28, %5, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
+// CHECK-NEXT: %31 = stablehlo.dot_general %28, %5, contracting_dims = [1] x [0] {enzymexla.guaranteed_symmetric = false} : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT: %32 = stablehlo.reduce(%28 init: %cst_7) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
 // CHECK-NEXT: %33 = stablehlo.add %25#4, %32 : tensor<3xf32>
 // CHECK-NEXT: %34 = stablehlo.transpose %25#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
-// CHECK-NEXT: %35 = stablehlo.add %31, %25#1 : tensor<3x3xf32>
+// CHECK-NEXT: %35 = stablehlo.add %31, %25#1 {enzymexla.guaranteed_symmetric = false} : tensor<3x3xf32>
 // CHECK-NEXT: %36 = stablehlo.transpose %35, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
 // CHECK-NEXT: return %14, %15, %16, %arg3, %arg4, %arg5, %36, %34, %30, %33 : tensor<6x2x3xf32>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<2xui64>, tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>, tensor<3xf32>
 // CHECK-NEXT: }

test/lit_tests/raising/affine_to_stablehlo13.mlir

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ module {
 // CHECK-NEXT: %2 = stablehlo.concatenate %1, %0, dim = 2 : (tensor<3x3x1xi64>, tensor<3x3x1xi64>) -> tensor<3x3x2xi64>
 // CHECK-NEXT: %3 = stablehlo.reshape %2 : (tensor<3x3x2xi64>) -> tensor<9x2xi64>
 // CHECK-NEXT: %4 = "stablehlo.gather"(%arg1, %3) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<3x3xi64>, tensor<9x2xi64>) -> tensor<9xi64>
-// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<9xi64>) -> tensor<3x3xi64>
+// CHECK-NEXT: %5 = stablehlo.reshape %4 {enzymexla.guaranteed_symmetric = false} : (tensor<9xi64>) -> tensor<3x3xi64>
 // CHECK-NEXT: %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<3x3xi64>) -> tensor<3x3xi64>
 // CHECK-NEXT: return %6, %arg1, %arg2 : tensor<3x3xi64>, tensor<3x3xi64>, tensor<3xi64>
 // CHECK-NEXT: }
