
Commit 6a60956

Merge branch 'main' into jm/multislice_opt
2 parents 61ffc4c + 715e2bf

File tree

3 files changed: 358 additions & 5 deletions

src/enzyme_ad/jax/Passes/AutoBatching.cpp

Lines changed: 249 additions & 5 deletions
@@ -976,8 +976,7 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
   for (auto &[op, slices] : userOpToSlicesMap) {
     bool avoidBatching =
         llvm::TypeSwitch<Operation *, bool>(op)
-            .Case<stablehlo::DynamicSliceOp, stablehlo::ReshapeOp,
-                  stablehlo::SliceOp,
+            .Case<stablehlo::ReshapeOp, stablehlo::SliceOp,
                   // TODO: avoid scatter since that lowers to loop right now
                   stablehlo::ScatterOp>([=](auto op) { return true; })
             .Case<stablehlo::BroadcastInDimOp, stablehlo::TransposeOp>(
@@ -987,9 +986,13 @@ LogicalResult GreedyWhileLoopBatchFission::matchAndRewriteImpl(
       continue;
     }
 
-    if ((dyn_cast<BatchOpInterface>(op) ||
-         stablehlo::hasTraitElementwise(op)) &&
-        op->getNumResults() == 1) {
+    if (auto dsOp = dyn_cast<stablehlo::DynamicSliceOp>(op)) {
+      if (raiseDynamicSliceToGather(rewriter, whileOp, slices, dsOp, info)) {
+        anyOpRewritten = true;
+      }
+    } else if ((dyn_cast<BatchOpInterface>(op) ||
+                stablehlo::hasTraitElementwise(op)) &&
+               op->getNumResults() == 1) {
       if (liftOperationByBatching(rewriter, whileOp, slices, op, info)) {
         anyOpRewritten = true;
       } else if (liftReduceLikeOperation(rewriter, whileOp, slices, op, info)) {
@@ -1440,6 +1443,247 @@ bool liftReduceLikeOperation(
   return true;
 }
 
+bool raiseDynamicSliceToGather(
+    PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
+    ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices,
+    stablehlo::DynamicSliceOp dsOp, WhileLoopInfo info) {
+  // Pattern: x[..., y[idx], z[idx], ...] where idx is an affine function of
+  // the loop induction variable. We need to:
+  // 1. Hoist the computation of y[idx], z[idx], etc. for all loop iterations
+  // 2. Create a gather operation from x using those indices
+  // 3. Replace dsOp uses with a dynamic slice into the gather result
+
+  // Find all start indices that depend on the loop and come from inner
+  // dynamic slices (possibly through a reshape)
+  SmallVector<int64_t> dependentDims;
+  SmallVector<Value> innerSliceOperands;
+  SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> innerSliceInfos;
+
+  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
+    // Use traverseOperandsForHoisting to classify this operand
+    SmallVector<BatchLiftingMode> modes;
+    SmallVector<Value> operands;
+    SmallVector<SmallVector<int64_t>> dims;
+    SmallVector<int64_t> hoisted;
+    SmallVector<SliceInfo<stablehlo::DynamicSliceOp>> mapped;
+    DenseMap<Value, SmallVector<Operation *>> hoistMap;
+
+    SmallVector<Value> singleOperand = {startIndex};
+    if (!traverseOperandsForHoisting(singleOperand, whileOp, slices, info,
+                                     modes, operands, dims, hoisted, mapped,
+                                     hoistMap)) {
+      return false;
+    }
+
+    if (modes[0] == BatchLiftingMode::DYNAMIC_SLICE) {
+      dependentDims.push_back(i);
+      innerSliceOperands.push_back(operands[0]);
+      innerSliceInfos.push_back(mapped[0]);
+    }
+  }
+
+  if (dependentDims.empty()) {
+    return false;
+  }
+
+  // Get the outer operand; it must be constant across iterations (loop
+  // invariant)
+  Value outerOperand;
+  Value dsOperand = dsOp.getOperand();
+  SmallVector<Operation *> canBeHoisted;
+  if (!info.isConstantAcrossIterations(dsOperand, outerOperand, canBeHoisted,
+                                       true)) {
+    return false;
+  }
+  if (!outerOperand) {
+    // The operand is defined inside the loop but is hoistable; hoist it
+    DenseMap<Value, SmallVector<Operation *>> hoistMap;
+    hoistMap[dsOperand] = canBeHoisted;
+    DenseMap<Value, Value> hoistedValues;
+    hoistChainOfOps(hoistMap, rewriter, whileOp, info, hoistedValues);
+    outerOperand = hoistedValues[dsOperand];
+  }
+
+  // Verify all non-dependent start indices are constant across iterations
+  for (auto [i, startIndex] : llvm::enumerate(dsOp.getStartIndices())) {
+    if (llvm::is_contained(dependentDims, i)) {
+      continue;
+    }
+    if (!info.isConstantAcrossIterations(startIndex, true)) {
+      return false;
+    }
+  }
+
+  int64_t numIters = info.getConstantNumIters();
+  Location loc = dsOp.getLoc();
+
+  rewriter.setInsertionPoint(whileOp);
+
+  // Step 1: Hoist each inner slice operand and construct the gather indices.
+  // We need to gather all indices for all loop iterations and concatenate them.
+  SmallVector<Value> hoistedIndicesList;
+  Type hoistedIndicesElemTy;
+
+  for (size_t idx = 0; idx < dependentDims.size(); idx++) {
+    Value hoistedIndices;
+    if (!info.hoistOperationFromLoop(
+            rewriter, innerSliceOperands[idx], innerSliceInfos[idx].sliceOp,
+            innerSliceInfos[idx].dimensions, hoistedIndices)) {
+      return false;
+    }
+
+    auto hoistedTy = cast<RankedTensorType>(hoistedIndices.getType());
+    if (idx == 0) {
+      hoistedIndicesElemTy = hoistedTy.getElementType();
+    }
+
+    // Reshape to [numIters, 1] for use as gather indices
+    SmallVector<int64_t> reshapeShape = {numIters, 1};
+    auto reshapeTy = RankedTensorType::get(reshapeShape, hoistedIndicesElemTy);
+
+    // Convert the element type if needed
+    if (hoistedTy.getElementType() != hoistedIndicesElemTy) {
+      hoistedIndices = stablehlo::ConvertOp::create(
+          rewriter, loc,
+          RankedTensorType::get(hoistedTy.getShape(), hoistedIndicesElemTy),
+          hoistedIndices);
+    }
+
+    Value reshaped =
+        stablehlo::ReshapeOp::create(rewriter, loc, reshapeTy, hoistedIndices);
+    hoistedIndicesList.push_back(reshaped);
+  }
+
+  // Concatenate all hoisted indices along the last dimension
+  Value gatherIndices;
+  if (hoistedIndicesList.size() == 1) {
+    gatherIndices = hoistedIndicesList[0];
+  } else {
+    // Result shape: [numIters, numDependentDims]
+    SmallVector<int64_t> concatShape = {numIters,
+                                        (int64_t)dependentDims.size()};
+    auto concatTy = RankedTensorType::get(concatShape, hoistedIndicesElemTy);
+    gatherIndices = stablehlo::ConcatenateOp::create(
+        rewriter, loc, concatTy, hoistedIndicesList, /*dimension=*/1);
+  }
+
+  // Step 2: Create the gather operation from the outer operand
+  auto outerOperandTy = cast<RankedTensorType>(outerOperand.getType());
+  auto dsSliceSizes = dsOp.getSliceSizes();
+
+  // The gather slice sizes: dependent dimensions get 1, the others keep their
+  // original size
+  SmallVector<int64_t> gatherSliceSizes;
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (llvm::is_contained(dependentDims, i)) {
+      gatherSliceSizes.push_back(1);
+    } else {
+      gatherSliceSizes.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  // offsetDims: output dimensions corresponding to non-collapsed slice dims.
+  // Start at 1 since the batch dim is at position 0, then consecutive for each
+  // non-collapsed dimension
+  SmallVector<int64_t> offsetDims;
+  int64_t offsetDimIdx = 1; // Start after the batch dimension
+  for (size_t i = 0; i < outerOperandTy.getRank(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      offsetDims.push_back(offsetDimIdx);
+      offsetDimIdx++;
+    }
+  }
+
+  // collapsedSliceDims: the dimensions we're indexing into
+  SmallVector<int64_t> collapsedSliceDims = llvm::to_vector(dependentDims);
+
+  // startIndexMap: maps index vector dimensions to operand dimensions
+  SmallVector<int64_t> startIndexMap = llvm::to_vector(dependentDims);
+
+  // Calculate output shape: [numIters, ...sliceSizes for non-dependent dims...]
+  SmallVector<int64_t> gatherOutputShape;
+  gatherOutputShape.push_back(numIters);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      gatherOutputShape.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  auto gatherResultTy =
+      RankedTensorType::get(gatherOutputShape, outerOperandTy.getElementType());
+
+  auto gatherOp = stablehlo::GatherOp::create(
+      rewriter, loc, gatherResultTy, outerOperand, gatherIndices,
+      stablehlo::GatherDimensionNumbersAttr::get(
+          rewriter.getContext(),
+          /*offsetDims=*/offsetDims,
+          /*collapsedSliceDims=*/collapsedSliceDims,
+          /*operandBatchingDims=*/{},
+          /*startIndicesBatchingDims=*/{},
+          /*startIndexMap=*/startIndexMap,
+          /*indexVectorDim=*/1),
+      gatherSliceSizes);
+
+  // Step 3: Replace the dsOp with a dynamic slice into the gather result.
+  // The dynamic slice indexes using the loop induction variable
+  rewriter.setInsertionPointAfter(dsOp);
+
+  auto inductionVar = info.getInductionVariable();
+  auto inductionVarType = cast<RankedTensorType>(inductionVar.getType());
+
+  // Compute the index for the dynamic slice
+  Value sliceIndex;
+  if (info.isConstantStart() && info.getConstantStart() == 0) {
+    sliceIndex = inductionVar;
+  } else {
+    sliceIndex = stablehlo::SubtractOp::create(rewriter, loc, inductionVar,
+                                               info.getStart());
+  }
+  if (!info.isStepOne()) {
+    sliceIndex = stablehlo::DivOp::create(rewriter, loc, sliceIndex,
+                                          info.getStep(rewriter));
+  }
+
+  // Convert sliceIndex to the same type as the gather indices if needed
+  if (inductionVarType.getElementType() != hoistedIndicesElemTy) {
+    auto newIndexTy = RankedTensorType::get({}, hoistedIndicesElemTy);
+    sliceIndex =
+        stablehlo::ConvertOp::create(rewriter, loc, newIndexTy, sliceIndex);
+  }
+
+  // Create constZero with the same type as sliceIndex (after conversion)
+  auto sliceIndexTy = cast<RankedTensorType>(sliceIndex.getType());
+  auto constZero = stablehlo::ConstantOp::create(
+      rewriter, loc, sliceIndexTy,
+      cast<ElementsAttr>(makeAttr(sliceIndexTy, 0)));
+  // Build the start indices for the dynamic slice
+  SmallVector<Value> dynSliceStarts;
+  dynSliceStarts.push_back(sliceIndex);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      dynSliceStarts.push_back(constZero);
+    }
+  }
+
+  // Build the slice sizes (1 for the batch dim, original sizes for others)
+  SmallVector<int64_t> dynSliceSizes;
+  dynSliceSizes.push_back(1);
+  for (size_t i = 0; i < dsSliceSizes.size(); i++) {
+    if (!llvm::is_contained(dependentDims, i)) {
+      dynSliceSizes.push_back(dsSliceSizes[i]);
+    }
+  }
+
+  auto dynSlice = stablehlo::DynamicSliceOp::create(
+      rewriter, loc, gatherOp.getResult(), dynSliceStarts, dynSliceSizes);
+
+  // Reshape to match the original dsOp output type
+  auto replacement =
+      stablehlo::ReshapeOp::create(rewriter, loc, dsOp.getType(), dynSlice);
+
+  rewriter.replaceOp(dsOp, replacement.getResult());
+
+  return true;
+}
+
 bool liftOperationByBatching(
     PatternRewriter &rewriter, stablehlo::WhileOp whileOp,
     ArrayRef<SliceInfo<stablehlo::DynamicSliceOp>> slices, Operation *op,

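In short, the new raiseDynamicSliceToGather recognizes loop bodies that compute x[..., y[idx], z[idx], ...], where idx is an affine function of the induction variable, hoists the per-iteration indices, materializes a single stablehlo.gather in front of the loop, and redirects the in-loop dynamic slice at the gather result. A minimal sketch of the rewrite, illustrative only (not taken from the commit; %x, %idx, %iv and all shapes are hypothetical, assuming a loop over iv = 0..7 with start 0 and step 1):

  Before, inside the while body (%iv is the induction variable):
    %y = stablehlo.dynamic_slice %idx, %iv, sizes = [1] : (tensor<8xi32>, tensor<i32>) -> tensor<1xi32>
    %i = stablehlo.reshape %y : (tensor<1xi32>) -> tensor<i32>
    %v = stablehlo.dynamic_slice %x, %i, sizes = [1] : (tensor<16xf32>, tensor<i32>) -> tensor<1xf32>

  After, Steps 1 and 2, emitted once before the loop:
    %gidx = stablehlo.reshape %idx : (tensor<8xi32>) -> tensor<8x1xi32>
    %g = "stablehlo.gather"(%x, %gidx) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<16xf32>, tensor<8x1xi32>) -> tensor<8xf32>

  After, Step 3, replacing the slice in the body, followed by a reshape back to the original result type (for a loop with a nonzero start or a non-unit step, the row index is (%iv - start) / step):
    %v = stablehlo.dynamic_slice %g, %iv, sizes = [1] : (tensor<8xf32>, tensor<i32>) -> tensor<1xf32>

Once nothing left in the body depends on the iteration, later simplifications can remove the loop entirely, as the tests below show.
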
src/enzyme_ad/jax/Passes/AutoBatching.h

Lines changed: 5 additions & 0 deletions
@@ -144,6 +144,11 @@ struct SliceToBatchElementwise : public SliceToBatchBase {
       : SliceToBatchBase(CheckElementwise, ctx, benefit) {}
 };
 
+bool raiseDynamicSliceToGather(
+    mlir::PatternRewriter &rewriter, mlir::stablehlo::WhileOp whileOp,
+    llvm::ArrayRef<SliceInfo<mlir::stablehlo::DynamicSliceOp>> slices,
+    mlir::stablehlo::DynamicSliceOp dsOp, mlir::enzyme::WhileLoopInfo info);
+
 bool liftOperationByBatching(
     mlir::PatternRewriter &rewriter, mlir::stablehlo::WhileOp whileOp,
     llvm::ArrayRef<SliceInfo<mlir::stablehlo::DynamicSliceOp>> slices,
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt="enable_auto_batching_passes=true" %s | FileCheck %s
+
+module {
+  func.func @main(%arg0: tensor<274xcomplex<f64>>, %arg1: tensor<274xf64>, %arg2: tensor<274xf64>, %arg3: tensor<274xf64>, %arg4: tensor<274xf64>, %arg5: tensor<274xi32>, %arg6: tensor<274xi32>, %arg7: tensor<274xi32>, %arg8: tensor<274xi32>, %arg9: tensor<126xf64>, %arg10: tensor<126xf64>) -> tensor<274xcomplex<f64>> {
+    %c = stablehlo.constant dense<1> : tensor<274xi32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i32>
+    %c_2 = stablehlo.constant dense<1> : tensor<i32>
+    %c_3 = stablehlo.constant dense<274> : tensor<i32>
+    %8:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %arg0) : tensor<i32>, tensor<274xcomplex<f64>>
+     cond {
+      %9 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      stablehlo.return %9 : tensor<i1>
+    } do {
+      %9 = stablehlo.add %c_2, %iterArg {enzymexla.bounds = [[1, 274]]} : tensor<i32>
+      %11 = stablehlo.subtract %9, %c_0 {enzymexla.bounds = [[0, 273]]} : tensor<i32>
+      %12 = stablehlo.dynamic_slice %arg0, %11, sizes = [1] : (tensor<274xcomplex<f64>>, tensor<i32>) -> tensor<1xcomplex<f64>>
+      %13 = stablehlo.dynamic_slice %arg5, %iterArg, sizes = [1] : (tensor<274xi32>, tensor<i32>) -> tensor<1xi32>
+      %14 = stablehlo.reshape %13 : (tensor<1xi32>) -> tensor<i32>
+      %15 = stablehlo.dynamic_slice %arg9, %14, sizes = [1] : (tensor<126xf64>, tensor<i32>) -> tensor<1xf64>
+      %16 = stablehlo.dynamic_slice %arg7, %iterArg, sizes = [1] : (tensor<274xi32>, tensor<i32>) -> tensor<1xi32>
+      %17 = stablehlo.reshape %16 : (tensor<1xi32>) -> tensor<i32>
+      %18 = stablehlo.dynamic_slice %arg10, %17, sizes = [1] : (tensor<126xf64>, tensor<i32>) -> tensor<1xf64>
+      %19 = stablehlo.complex %15, %18 : tensor<1xcomplex<f64>>
+      %20 = stablehlo.exponential %19 : tensor<1xcomplex<f64>>
+      %21 = stablehlo.dynamic_slice %arg6, %iterArg, sizes = [1] : (tensor<274xi32>, tensor<i32>) -> tensor<1xi32>
+      %22 = stablehlo.reshape %21 : (tensor<1xi32>) -> tensor<i32>
+      %23 = stablehlo.dynamic_slice %arg9, %22, sizes = [1] : (tensor<126xf64>, tensor<i32>) -> tensor<1xf64>
+      %24 = stablehlo.dynamic_slice %arg8, %iterArg, sizes = [1] : (tensor<274xi32>, tensor<i32>) -> tensor<1xi32>
+      %25 = stablehlo.reshape %24 : (tensor<1xi32>) -> tensor<i32>
+      %26 = stablehlo.dynamic_slice %arg10, %25, sizes = [1] : (tensor<126xf64>, tensor<i32>) -> tensor<1xf64>
+      %27 = stablehlo.complex %23, %26 : tensor<1xcomplex<f64>>
+      %28 = stablehlo.exponential %27 : tensor<1xcomplex<f64>>
+      %29 = chlo.conj %28 : tensor<1xcomplex<f64>> -> tensor<1xcomplex<f64>>
+      %30 = stablehlo.multiply %20, %12 : tensor<1xcomplex<f64>>
+      %31 = stablehlo.multiply %30, %29 : tensor<1xcomplex<f64>>
+      %32 = stablehlo.dynamic_update_slice %iterArg_4, %31, %11 : (tensor<274xcomplex<f64>>, tensor<1xcomplex<f64>>, tensor<i32>) -> tensor<274xcomplex<f64>>
+      stablehlo.return %9, %32 : tensor<i32>, tensor<274xcomplex<f64>>
+    }
+    return %8#1 : tensor<274xcomplex<f64>>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<274xcomplex<f64>>, %arg1: tensor<274xf64>, %arg2: tensor<274xf64>, %arg3: tensor<274xf64>, %arg4: tensor<274xf64>, %arg5: tensor<274xi32>, %arg6: tensor<274xi32>, %arg7: tensor<274xi32>, %arg8: tensor<274xi32>, %arg9: tensor<126xf64>, %arg10: tensor<126xf64>) -> tensor<274xcomplex<f64>> {
+// CHECK-NEXT:   %0 = stablehlo.reshape %arg8 : (tensor<274xi32>) -> tensor<274x1xi32>
+// CHECK-NEXT:   %1 = "stablehlo.gather"(%arg10, %0) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<126xf64>, tensor<274x1xi32>) -> tensor<274xf64>
+// CHECK-NEXT:   %2 = stablehlo.reshape %arg6 : (tensor<274xi32>) -> tensor<274x1xi32>
+// CHECK-NEXT:   %3 = "stablehlo.gather"(%arg9, %2) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<126xf64>, tensor<274x1xi32>) -> tensor<274xf64>
+// CHECK-NEXT:   %4 = stablehlo.reshape %arg7 : (tensor<274xi32>) -> tensor<274x1xi32>
+// CHECK-NEXT:   %5 = "stablehlo.gather"(%arg10, %4) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<126xf64>, tensor<274x1xi32>) -> tensor<274xf64>
+// CHECK-NEXT:   %6 = stablehlo.reshape %arg5 : (tensor<274xi32>) -> tensor<274x1xi32>
+// CHECK-NEXT:   %7 = "stablehlo.gather"(%arg9, %6) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1>}> : (tensor<126xf64>, tensor<274x1xi32>) -> tensor<274xf64>
+// CHECK-NEXT:   %8 = stablehlo.complex %7, %5 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %9 = stablehlo.complex %3, %1 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %10 = stablehlo.exponential %9 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %11 = chlo.conj %10 : tensor<274xcomplex<f64>> -> tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %12 = stablehlo.exponential %8 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %13 = stablehlo.multiply %12, %arg0 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   %14 = stablehlo.multiply %13, %11 : tensor<274xcomplex<f64>>
+// CHECK-NEXT:   return %14 : tensor<274xcomplex<f64>>
+// CHECK-NEXT: }
+
+module {
+  func.func @main(%arg0: tensor<18x3x10xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5xi32>) -> tensor<5x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<5xi32>
+    %c_1 = stablehlo.constant dense<5> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<5x3xf32>
+    %c_2 = stablehlo.constant dense<1> : tensor<i32>
+    %c_3 = stablehlo.constant dense<0> : tensor<i32>
+    %c_4 = stablehlo.constant dense<1> : tensor<i32>
+    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<18x3x10xf32>) -> tensor<10x3x18xf32>
+    %5 = stablehlo.subtract %arg2, %c_0 : tensor<5xi32>
+    %6 = stablehlo.subtract %arg1, %c_0 : tensor<5xi32>
+    %7:2 = stablehlo.while(%iterArg = %c_3, %iterArg_5 = %cst) : tensor<i32>, tensor<5x3xf32>
+     cond {
+      %8 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      stablehlo.return %8 : tensor<i1>
+    } do {
+      %8 = stablehlo.add %c_4, %iterArg {enzymexla.bounds = [[1, 5]]} : tensor<i32>
+      %9 = stablehlo.dynamic_slice %6, %iterArg, sizes = [1] : (tensor<5xi32>, tensor<i32>) -> tensor<1xi32>
+      %10 = stablehlo.reshape %9 : (tensor<1xi32>) -> tensor<i32>
+      %11 = stablehlo.dynamic_slice %5, %iterArg, sizes = [1] : (tensor<5xi32>, tensor<i32>) -> tensor<1xi32>
+      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
+      %13 = stablehlo.dynamic_slice %0, %10, %c, %12, sizes = [1, 3, 1] : (tensor<10x3x18xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x3x1xf32>
+      %15 = stablehlo.subtract %8, %c_2 {enzymexla.bounds = [[0, 4]]} : tensor<i32>
+      %16 = stablehlo.reshape %13 : (tensor<1x3x1xf32>) -> tensor<1x3xf32>
+      %17 = stablehlo.dynamic_update_slice %iterArg_5, %16, %15, %c : (tensor<5x3xf32>, tensor<1x3xf32>, tensor<i32>, tensor<i32>) -> tensor<5x3xf32>
+      stablehlo.return %8, %17 : tensor<i32>, tensor<5x3xf32>
+    }
+    return %7#1 : tensor<5x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<18x3x10xf32>, %arg1: tensor<5xi32>, %arg2: tensor<5xi32>) -> tensor<5x3xf32> {
+// CHECK-NEXT:   %c = stablehlo.constant dense<1> : tensor<5x2xi32>
+// CHECK-NEXT:   %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<18x3x10xf32>) -> tensor<10x3x18xf32>
+// CHECK-NEXT:   %1 = stablehlo.reshape %arg1 : (tensor<5xi32>) -> tensor<5x1xi32>
+// CHECK-NEXT:   %2 = stablehlo.reshape %arg2 : (tensor<5xi32>) -> tensor<5x1xi32>
+// CHECK-NEXT:   %3 = stablehlo.concatenate %1, %2, dim = 1 : (tensor<5x1xi32>, tensor<5x1xi32>) -> tensor<5x2xi32>
+// CHECK-NEXT:   %4 = stablehlo.subtract %3, %c : tensor<5x2xi32>
+// CHECK-NEXT:   %5 = "stablehlo.gather"(%0, %4) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0, 2], start_index_map = [0, 2], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 3, 1>}> : (tensor<10x3x18xf32>, tensor<5x2xi32>) -> tensor<5x3xf32>
+// CHECK-NEXT:   return %5 : tensor<5x3xf32>
+// CHECK-NEXT: }

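As a worked check of the dimension-number computation against the second test above: the outer operand is tensor<10x3x18xf32>, numIters = 5, and the start indices for dimensions 0 and 2 come from inner dynamic slices, so dependentDims = [0, 2]. The pass then derives gatherSliceSizes = [1, 3, 1] (dependent dims shrink to 1, dim 1 keeps its size), collapsedSliceDims = startIndexMap = [0, 2], and offsetDims = [1] (the single non-collapsed slice dim, placed after the batch dim at position 0). The hoisted indices are reshaped to tensor<5x1xi32> and concatenated into a tensor<5x2xi32> with indexVectorDim = 1, and the gather output shape is [numIters, 3] = tensor<5x3xf32>, exactly the stablehlo.gather in the final CHECK block.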