Skip to content
Open
30 changes: 9 additions & 21 deletions src/enzyme_ad/jax/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,13 +660,22 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
}
}

if (isa<enzymexla::SymmOp>(op)) {
return State::GUARANTEED;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this generally true? It is true for syrk, but not for symm if B is not symmetric.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, you're right — there's no op for syrk yet, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct


bool recursiveCheck = false;

// elementwise ops
if (stablehlo::hasTraitElementwise(op)) {
recursiveCheck = true;
}

if (isa<stablehlo::TransposeOp>(op) || isa<stablehlo::DotGeneralOp>(op)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for dot general it only holds if the operands commute, not in general I think

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

even then not necessarily since you could have batched dimensions

// All operands symmetric => symmetric result
recursiveCheck = true;
}

/**
* TODO
* - check if its * 0 -> symmetric
Expand Down Expand Up @@ -739,13 +748,6 @@ NoNanResultAnalysis::localGuaranteed(Operation *op,
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down Expand Up @@ -882,13 +884,6 @@ FiniteResultAnalysis::localGuaranteed(Operation *op,
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down Expand Up @@ -995,13 +990,6 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
PatternRewriter &rewriter) {
assert(op);

if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

DenseElementsAttr denseAttr;
if (matchPattern(op, m_Constant(&denseAttr))) {
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
Expand Down
12 changes: 11 additions & 1 deletion src/enzyme_ad/jax/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,18 @@ template <typename Child> class GuaranteedResultAnalysisBase {
State localGuaranteedWithSetAttr(Operation *op,
SmallVectorImpl<Operation *> &localtodo,
PatternRewriter &rewriter) {
auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);

auto attrName = ((Child *)this)->getAttrName();

if (auto boolAttr = op->getAttrOfType<BoolAttr>(attrName)) {
if (boolAttr.getValue())
return State::GUARANTEED;
else
return State::NOTGUARANTEED;
}

auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);

switch (state) {
case State::GUARANTEED:
rewriter.modifyOpInPlace(op, [&]() {
Expand Down
2 changes: 1 addition & 1 deletion test/lit_tests/diffrules/stablehlo/while4.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ module {
// CHECK-NEXT: %28 = stablehlo.multiply %25#5, %27 : tensor<3x2xf32>
// CHECK-NEXT: %29 = stablehlo.reduce(%28 init: %cst_7) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
// CHECK-NEXT: %30 = stablehlo.add %25#3, %29 : tensor<3xf32>
// CHECK-NEXT: %31 = stablehlo.dot_general %28, %5, contracting_dims = [1] x [0] {enzymexla.guaranteed_symmetric = false} : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
// CHECK-NEXT: %31 = stablehlo.dot_general %28, %5, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x3xf32>) -> tensor<3x3xf32>
// CHECK-NEXT: %32 = stablehlo.reduce(%28 init: %cst_7) applies stablehlo.add across dimensions = [1] : (tensor<3x2xf32>, tensor<f32>) -> tensor<3xf32>
// CHECK-NEXT: %33 = stablehlo.add %25#4, %32 : tensor<3xf32>
// CHECK-NEXT: %34 = stablehlo.transpose %25#2, dims = [1, 0] : (tensor<3x3xf32>) -> tensor<3x3xf32>
Expand Down
22 changes: 22 additions & 0 deletions test/lit_tests/structured_tensors/propagate_symmetric.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s

func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
%alpha = stablehlo.constant dense<2.0> : tensor<f32>
%beta = stablehlo.constant dense<3.0> : tensor<f32>
%0 = enzymexla.blas.symm %arg0, %arg1, %arg2, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
%1 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32>
%2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
%3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %4 : tensor<2x2xf32>
}

// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
// CHECK-NEXT: %0 = enzymexla.blas.symm %arg0, %arg1, %arg2, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
// CHECK-NEXT: %1 = stablehlo.reshape %arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
// CHECK-NEXT: %3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %3 : tensor<2x2xf32>
// CHECK-NEXT: }
Loading