Skip to content

Commit b3d6fde

Browse files
authored
feat: iota tensor detection + indirect iota indexing simplification (#1542)
* feat: iota tensor detection * chore: run fmt * feat: rewrite iota ops * test: indirect indexing * feat: support more iota like ops for scatter detection
1 parent c373ec1 commit b3d6fde

File tree

4 files changed

+267
-7
lines changed

4 files changed

+267
-7
lines changed

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1111
#include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h"
1212
#include "src/enzyme_ad/jax/Passes/Passes.h"
13+
#include "src/enzyme_ad/jax/Passes/StructuredTensors.h"
1314
#include "src/enzyme_ad/jax/Utils.h"
1415
#include "stablehlo/dialect/StablehloOps.h"
1516
#include "llvm/ADT/DenseMap.h"
@@ -796,6 +797,44 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
796797
if (candidateSlices.empty())
797798
return rewriter.notifyMatchFailure(whileOp, "no candidate slices found");
798799

800+
bool anyOpRewritten = false;
801+
802+
// iota [idx] where iota starts at 0 and iter var also starts at 0
803+
// replace this with idx
804+
// If we do a successful rewrite here, we remove the DynamicSliceInfo from
805+
// the candidateSlices vector (a later invocation will handle the rest)
806+
SmallVector<DynamicSliceInfo> retainedSlices;
807+
for (auto [i, slice] : llvm::enumerate(candidateSlices)) {
808+
auto iotaDetection = detectIotaLikeTensor(slice.sliceOp.getOperand());
809+
if (iotaDetection &&
810+
slice.inductionVarDimension == iotaDetection.value().dimension &&
811+
iotaDetection.value().start == 0 &&
812+
iotaDetection.value().limit == limit) {
813+
anyOpRewritten = true;
814+
815+
OpBuilder::InsertionGuard guard(rewriter);
816+
rewriter.setInsertionPoint(slice.sliceOp);
817+
Value newOperand = info.getInductionVariable();
818+
auto sliceType =
819+
cast<RankedTensorType>(slice.sliceOp.getResult().getType());
820+
auto outElemType = sliceType.getElementType();
821+
if (cast<TensorType>(newOperand.getType()).getElementType() !=
822+
outElemType) {
823+
newOperand = rewriter
824+
.create<stablehlo::ConvertOp>(
825+
slice.sliceOp.getLoc(),
826+
RankedTensorType::get({}, outElemType), newOperand)
827+
.getResult();
828+
}
829+
rewriter.replaceOpWithNewOp<stablehlo::BroadcastInDimOp>(
830+
slice.sliceOp, sliceType, newOperand,
831+
rewriter.getDenseI64ArrayAttr({}));
832+
} else {
833+
retainedSlices.push_back(slice);
834+
}
835+
}
836+
candidateSlices = std::move(retainedSlices);
837+
799838
// Create a map of user operations to their corresponding dynamic slices
800839
DenseMap<Operation *, SmallVector<DynamicSliceInfo>> userOpToSlicesMap;
801840
for (auto ds : candidateSlices) {
@@ -819,9 +858,8 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
819858
}
820859

821860
if (userOpToSlicesMap.empty())
822-
return failure();
861+
return anyOpRewritten ? success() : failure();
823862

824-
bool wasLifted = false;
825863
for (auto &[op, slices] : userOpToSlicesMap) {
826864
SmallVector<bool> allIntermediateReshapes(slices.size());
827865
for (auto [i, slice] : llvm::enumerate(slices))
@@ -839,17 +877,17 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
839877
op->hasTrait<OpTrait::Elementwise>()) {
840878
if (liftOperationByBatching(rewriter, whileOp, slices, op, info,
841879
intermediateReshape)) {
842-
wasLifted = true;
880+
anyOpRewritten = true;
843881
}
844882
} else if (!intermediateReshape && isa<stablehlo::ReshapeOp>(op)) {
845883
if (liftSpecialReshapeOp(rewriter, whileOp, slices,
846884
dyn_cast<stablehlo::ReshapeOp>(op), info)) {
847-
wasLifted = true;
885+
anyOpRewritten = true;
848886
}
849887
}
850888
}
851889

852-
return wasLifted ? success() : failure();
890+
return anyOpRewritten ? success() : failure();
853891
};
854892

855893
bool GreedyWhileLoopBatchFission::liftSpecialReshapeOp(

src/enzyme_ad/jax/Passes/StructuredTensors.cpp

Lines changed: 173 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
7575
return absl::InvalidArgumentError(
7676
"Scatter dimension numbers are not valid for a diagonal tensor.");
7777

78-
if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(indices.getDefiningOp())) {
79-
if (iotaOp.getIotaDimension() == 0) {
78+
auto isIotaLikeTensor = detectIotaLikeTensor(indices);
79+
if (isIotaLikeTensor) {
80+
auto iotaLikeTensor = isIotaLikeTensor.value();
81+
if (iotaLikeTensor.dimension == 0 && iotaLikeTensor.start == 0) {
8082
*outUpdates = updates;
8183
return absl::OkStatus();
8284
}
@@ -85,5 +87,174 @@ absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
8587
return absl::InvalidArgumentError("Not a diagonal tensor.");
8688
}
8789

90+
std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor) {
91+
if (!tensor)
92+
return std::nullopt;
93+
94+
auto elemType =
95+
cast<mlir::RankedTensorType>(tensor.getType()).getElementType();
96+
if (!isa<mlir::IntegerType>(elemType))
97+
return std::nullopt;
98+
99+
struct ChainItem {
100+
mlir::Operation *op;
101+
int64_t offset; // only populated for AddOp/SubtractOp
102+
};
103+
104+
// Build a chain of operations from startOp to the base case
105+
SmallVector<ChainItem> chain;
106+
llvm::DenseSet<mlir::Operation *> visited;
107+
mlir::Operation *currentOp = tensor.getDefiningOp();
108+
109+
// Traverse to find base case
110+
while (currentOp && !visited.contains(currentOp)) {
111+
visited.insert(currentOp);
112+
113+
// check if we found a base case
114+
if (isa<stablehlo::IotaOp, stablehlo::ConstantOp>(currentOp)) {
115+
chain.push_back({currentOp, 0});
116+
break;
117+
}
118+
119+
// navigate to the next op. If any unsupported intermediate op is found,
120+
// then return std::nullopt
121+
Operation *nextOp;
122+
123+
// TODO: we might want to support broadcast_in_dim / insert_dims / drop_dims
124+
// as well
125+
if (isa<stablehlo::TransposeOp>(currentOp)) {
126+
chain.push_back({currentOp, 0});
127+
nextOp = currentOp->getOperand(0).getDefiningOp();
128+
} else if (auto convertOp = dyn_cast<stablehlo::ConvertOp>(currentOp)) {
129+
// if operand of convertOp is not a integer, then return std::nullopt
130+
if (!isa<mlir::IntegerType>(
131+
cast<TensorType>(convertOp.getOperand().getType())
132+
.getElementType()))
133+
return std::nullopt;
134+
chain.push_back({currentOp, 0});
135+
nextOp = convertOp.getOperand().getDefiningOp();
136+
} else if (auto addOp = dyn_cast<stablehlo::AddOp>(currentOp)) {
137+
APInt offsetVal;
138+
if (matchPattern(addOp.getRhs(), m_ConstantInt(&offsetVal))) {
139+
chain.push_back({currentOp, offsetVal.getSExtValue()});
140+
nextOp = addOp.getLhs().getDefiningOp();
141+
} else if (matchPattern(addOp.getLhs(), m_ConstantInt(&offsetVal))) {
142+
chain.push_back({currentOp, offsetVal.getSExtValue()});
143+
nextOp = addOp.getRhs().getDefiningOp();
144+
} else {
145+
return std::nullopt;
146+
}
147+
} else if (auto subOp = dyn_cast<stablehlo::SubtractOp>(currentOp)) {
148+
APInt offsetVal;
149+
if (matchPattern(subOp.getRhs(), m_ConstantInt(&offsetVal))) {
150+
chain.push_back({currentOp, -offsetVal.getSExtValue()});
151+
nextOp = subOp.getLhs().getDefiningOp();
152+
} else {
153+
return std::nullopt;
154+
}
155+
} else { // unsupported op
156+
return std::nullopt;
157+
}
158+
159+
currentOp = nextOp;
160+
}
161+
162+
if (chain.empty())
163+
return std::nullopt;
164+
165+
// process the base case
166+
IotaLikeTensor result;
167+
if (auto iotaOp = dyn_cast<stablehlo::IotaOp>(chain.back().op)) {
168+
auto iotaType = cast<RankedTensorType>(iotaOp.getResult().getType());
169+
auto iotaDim = static_cast<int64_t>(iotaOp.getIotaDimension());
170+
result = IotaLikeTensor{0, iotaType.getShape()[iotaDim], iotaDim, iotaType};
171+
} else if (auto constantOp =
172+
dyn_cast<stablehlo::ConstantOp>(chain.back().op)) {
173+
auto denseAttr = cast<DenseElementsAttr>(constantOp.getValue());
174+
auto constType = cast<RankedTensorType>(constantOp.getResult().getType());
175+
auto shape = constType.getShape();
176+
177+
if (denseAttr.isSplat())
178+
return std::nullopt;
179+
180+
// Calculate strides for indexing
181+
SmallVector<int64_t> strides(constType.getRank(), 1);
182+
for (int64_t i = constType.getRank() - 2; i >= 0; --i) {
183+
strides[i] = strides[i + 1] * shape[i + 1];
184+
}
185+
186+
bool isIotaLike = false;
187+
auto denseAttrValues = denseAttr.getValues<APInt>();
188+
189+
for (int64_t dim = 0; dim < constType.getRank(); dim++) {
190+
bool isIotaAlongDim = true;
191+
std::optional<int64_t> detectedStart;
192+
193+
SmallVector<int64_t> indices(constType.getRank(), 0);
194+
int64_t numElements = constType.getNumElements();
195+
196+
for (int64_t idx = 0; idx < numElements && isIotaAlongDim; idx++) {
197+
int64_t temp = idx;
198+
// linear to cartesian indexing
199+
for (int64_t d = 0; d < constType.getRank(); d++) {
200+
indices[d] = temp / strides[d];
201+
temp = temp % strides[d];
202+
}
203+
204+
int64_t actualValue = denseAttrValues[idx].getSExtValue();
205+
206+
if (!detectedStart) {
207+
detectedStart = actualValue;
208+
}
209+
210+
int64_t expectedValue = detectedStart.value() + indices[dim];
211+
if (actualValue != expectedValue) {
212+
isIotaAlongDim = false;
213+
break;
214+
}
215+
}
216+
217+
if (isIotaAlongDim && detectedStart) {
218+
isIotaLike = true;
219+
result =
220+
IotaLikeTensor{detectedStart.value(),
221+
detectedStart.value() + shape[dim], dim, constType};
222+
break;
223+
}
224+
}
225+
226+
if (!isIotaLike)
227+
return std::nullopt;
228+
} else {
229+
return std::nullopt;
230+
}
231+
232+
// traverse the chain in reverse order
233+
for (int64_t i = chain.size() - 2; i >= 0; i--) {
234+
auto item = chain[i];
235+
236+
if (isa<stablehlo::ConvertOp>(item.op)) {
237+
continue;
238+
} else if (auto transposeOp = dyn_cast<stablehlo::TransposeOp>(item.op)) {
239+
auto permutation = transposeOp.getPermutation();
240+
for (int64_t idx = 0; idx < permutation.size(); idx++) {
241+
if (permutation[idx] == result.dimension) {
242+
result.dimension = idx;
243+
break;
244+
}
245+
}
246+
continue;
247+
} else if (isa<stablehlo::AddOp, stablehlo::SubtractOp>(item.op)) {
248+
result.start += item.offset;
249+
continue;
250+
}
251+
252+
assert(false && "reached unreachable case...");
253+
}
254+
255+
result.tensorType = cast<RankedTensorType>(tensor.getType());
256+
return result;
257+
}
258+
88259
} // namespace enzyme
89260
} // namespace mlir

src/enzyme_ad/jax/Passes/StructuredTensors.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "src/enzyme_ad/jax/Utils.h"
66
#include "stablehlo/dialect/StablehloOps.h"
77

8+
#include <optional>
9+
810
namespace mlir {
911
namespace enzyme {
1012

@@ -16,5 +18,14 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp,
1618
absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp,
1719
mlir::Value *outUpdates);
1820

21+
struct IotaLikeTensor {
22+
int64_t start;
23+
int64_t limit;
24+
int64_t dimension;
25+
mlir::RankedTensorType tensorType;
26+
};
27+
28+
std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor);
29+
1930
} // namespace enzyme
2031
} // namespace mlir
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_dynamic_slice(1);while_is_copy_simplify;greedy_while_loop_batch_fission;broadcast_to_reshape;merge_consecutive_reshapes;reshape_licm(0)" --transform-interpreter --enzyme-hlo-remove-transform --inline --enzyme-hlo-opt --enzyme-hlo-generate-td="patterns=reshape_dynamic_slice(1);while_is_copy_simplify;greedy_while_loop_batch_fission;broadcast_to_reshape;merge_consecutive_reshapes;reshape_licm(0);reshape_elementwise(0)" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
2+
3+
module {
4+
func.func @main(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf32> {
5+
%c = stablehlo.constant dense<1> : tensor<i32>
6+
%c_0 = stablehlo.constant dense<1> : tensor<i64>
7+
%c_1 = stablehlo.constant dense<10> : tensor<i64>
8+
%c_2 = stablehlo.constant dense<0> : tensor<i64>
9+
%cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf32>
10+
%c_3 = stablehlo.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xi32>
11+
%0:2 = stablehlo.while(%iterArg = %c_2, %iterArg_4 = %cst) : tensor<i64>, tensor<10xf32>
12+
cond {
13+
%1 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
14+
stablehlo.return %1 : tensor<i1>
15+
} do {
16+
%1 = stablehlo.add %iterArg, %c_0 : tensor<i64>
17+
%2 = stablehlo.dynamic_slice %c_3, %iterArg, sizes = [1] : (tensor<10xi32>, tensor<i64>) -> tensor<1xi32>
18+
%3 = stablehlo.reshape %2 : (tensor<1xi32>) -> tensor<i32>
19+
%4 = stablehlo.dynamic_slice %arg0, %3, sizes = [1] : (tensor<10xf64>, tensor<i32>) -> tensor<1xf64>
20+
%5 = stablehlo.dynamic_slice %arg1, %3, sizes = [1] : (tensor<10xf64>, tensor<i32>) -> tensor<1xf64>
21+
%6 = stablehlo.add %4, %5 : tensor<1xf64>
22+
%7 = stablehlo.maximum %4, %5 : tensor<1xf64>
23+
%8 = stablehlo.add %6, %7 : tensor<1xf64>
24+
%9 = stablehlo.convert %8 : (tensor<1xf64>) -> tensor<1xf32>
25+
%10 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
26+
%11 = stablehlo.subtract %10, %c : tensor<i32>
27+
%12 = stablehlo.dynamic_update_slice %iterArg_4, %9, %11 : (tensor<10xf32>, tensor<1xf32>, tensor<i32>) -> tensor<10xf32>
28+
stablehlo.return %1, %12 : tensor<i64>, tensor<10xf32>
29+
}
30+
return %0#1 : tensor<10xf32>
31+
}
32+
}
33+
34+
// CHECK: func.func @main(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf32> {
35+
// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 : tensor<10xf64>
36+
// CHECK-NEXT: %1 = stablehlo.maximum %arg0, %arg1 : tensor<10xf64>
37+
// CHECK-NEXT: %2 = stablehlo.add %0, %1 : tensor<10xf64>
38+
// CHECK-NEXT: %3 = stablehlo.convert %2 : (tensor<10xf64>) -> tensor<10xf32>
39+
// CHECK-NEXT: return %3 : tensor<10xf32>
40+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)