Skip to content

Commit 980931d

Browse files
authored
refactor: generalize existing passes to operate on traits (#1603)
* fix: bug fix and generalize transpose_unary_transpose patterns * feat: generalize commutative associative reordering based on traits * refactor: rewrite associative bin op * feat: generalize BinaryOpTransposeSimplify * ci: use specific reactant commit * chore: run fmt * fix: induction var in remove_no_ops * test: use atol and rtol * fix: tolerance
1 parent b98ffe2 commit 980931d

File tree

14 files changed

+320
-377
lines changed

14 files changed

+320
-377
lines changed

.github/workflows/test-gb-25.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ jobs:
5353
- 'main'
5454
# - '0123456789abcdef0123456789abcdef01234567'
5555
reactant_commit:
56-
- 'main'
57-
# - 'regenerate-mlir-bindings'
56+
# - 'main'
57+
- 'ap/pass_generalization'
5858

5959
steps:
6060
- name: Check GPUs

src/enzyme_ad/jax/Implementations/WhileLoopInfo.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,17 @@ struct WhileLoopInfo {
3939
std::optional<int64_t> getConstantStart();
4040
std::optional<int64_t> getConstantLimit();
4141

42-
Value getInductionVariable() { return op.getBody().front().getArgument(0); }
42+
// assumes computeInfo() has been called and was successful
43+
// returns the induction variable in the body of the while op
44+
Value getInductionVariable() {
45+
auto &condBlk = op.getCond().front();
46+
auto condTerm = cast<stablehlo::ReturnOp>(condBlk.getTerminator());
47+
auto condV = condTerm->getOperand(0);
48+
auto cond = condV.getDefiningOp<stablehlo::CompareOp>();
49+
auto induct = dyn_cast<BlockArgument>(cond.getOperand(0));
50+
auto blockArgNum = induct.getArgNumber();
51+
return op.getBody().front().getArgument(blockArgNum);
52+
}
4353

4454
int64_t getConstantNumIters();
4555
Value getNumIters(OpBuilder &builder);

0 commit comments

Comments
 (0)