Skip to content

Commit 3539679

Browse files
committed
more vibe
1 parent e03a7b5 commit 3539679

File tree

4 files changed

48 additions & 2 deletions

4 files changed

48 additions & 2 deletions

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14607,6 +14607,7 @@ struct GatherOpCanon final
1460714607
int64_t affineDim;
1460814608
int64_t scale;
1460914609
int64_t start;
14610+
int64_t componentIndex;
1461014611
};
1461114612

1461214613
LogicalResult
@@ -14650,7 +14651,8 @@ struct GatherOpCanon final
1465014651
dnums.getStartIndexMap()[0],
1465114652
affineDim,
1465214653
static_cast<int64_t>(*scaleVal),
14653-
static_cast<int64_t>(*startVal)
14654+
static_cast<int64_t>(*startVal),
14655+
static_cast<int64_t>(i)
1465414656
});
1465514657
}
1465614658

@@ -14676,7 +14678,12 @@ struct GatherOpCanon final
1467614678

1467714679
int64_t currentScale = 1;
1467814680
for (const auto& ds : dimStrides) {
14679-
int64_t implied_dim_size = baseIndicesTy.getDimSize(ds.affineDim);
14681+
int64_t implied_dim_size;
14682+
if (affineIota.implied_dim_sizes.size() > static_cast<size_t>(ds.componentIndex)) {
14683+
implied_dim_size = affineIota.implied_dim_sizes[ds.componentIndex];
14684+
} else {
14685+
implied_dim_size = baseIndicesTy.getDimSize(ds.affineDim);
14686+
}
1468014687
if (ds.scale != currentScale) {
1468114688
llvm::errs() << "BAIL affine iota matching: strides not forming a contiguous block (" << ds.scale << " vs " << currentScale << ")\n";
1468214689
return failure();

src/enzyme_ad/jax/Utils.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1866,6 +1866,11 @@ std::optional<AffineIotaLikeTensor> detectAffineIotaLikeTensor(mlir::Value tenso
18661866
AffineIotaLikeTensor result;
18671867
result.starts.push_back(iotaLike->start);
18681868
result.dimensions.push_back(iotaLike->dimension);
1869+
1870+
// Explicitly record the implied dimension size from the original iota shape
1871+
auto indicesTy = cast<RankedTensorType>(iotaLike->tensorType);
1872+
result.implied_dim_sizes.push_back(indicesTy.getDimSize(iotaLike->dimension));
1873+
18691874
result.scales.push_back(iotaLike->scale);
18701875
result.tensorType = iotaLike->tensorType;
18711876
return result;
@@ -1916,6 +1921,7 @@ std::optional<AffineIotaLikeTensor> detectAffineIotaLikeTensor(mlir::Value tenso
19161921
AffineIotaLikeTensor combined = *lhsAffine;
19171922
combined.starts.insert(combined.starts.end(), rhsAffine->starts.begin(), rhsAffine->starts.end());
19181923
combined.dimensions.insert(combined.dimensions.end(), rhsAffine->dimensions.begin(), rhsAffine->dimensions.end());
1924+
combined.implied_dim_sizes.insert(combined.implied_dim_sizes.end(), rhsAffine->implied_dim_sizes.begin(), rhsAffine->implied_dim_sizes.end());
19191925
combined.scales.insert(combined.scales.end(), rhsAffine->scales.begin(), rhsAffine->scales.end());
19201926
combined.tensorType = cast<RankedTensorType>(addOp.getType());
19211927
return combined;

src/enzyme_ad/jax/Utils.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1114,6 +1114,7 @@ std::optional<IotaLikeTensor> detectIotaLikeTensor(mlir::Value tensor);
11141114
struct AffineIotaLikeTensor {
11151115
llvm::SmallVector<mlir::TypedAttr> starts;
11161116
llvm::SmallVector<int64_t> dimensions;
1117+
llvm::SmallVector<int64_t> implied_dim_sizes;
11171118
llvm::SmallVector<mlir::TypedAttr> scales;
11181119
mlir::RankedTensorType tensorType;
11191120
};

test/lit_tests/gather_iota_to_slice.mlir

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -231,3 +231,35 @@ func.func @gather_iota_negative_stride_offset_to_slice_reverse(%arg0: tensor<10x
231231
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [1:8:2]
232232
// CHECK-NEXT: %[[REVERSE:.+]] = stablehlo.reverse %[[SLICE]], dims = [0]
233233
// CHECK-NEXT: return %[[REVERSE]]
234+
235+
// ============================================================================
236+
// Tests for gather with multi-dimensional iota reshaped into a flat index
237+
// ============================================================================
238+
239+
// Reshaped indices: gather with a 2D grid of indices flattened
240+
func.func @gather_reshaped_iota_to_slice(%arg0: tensor<10xi64>) -> tensor<4xi64> {
241+
%iota1 = stablehlo.iota dim = 0 : tensor<2x2xi64>
242+
%iota2 = stablehlo.iota dim = 1 : tensor<2x2xi64>
243+
%c2 = stablehlo.constant dense<2> : tensor<2x2xi64>
244+
%scaled = stablehlo.multiply %iota1, %c2 : tensor<2x2xi64>
245+
%added = stablehlo.add %scaled, %iota2 : tensor<2x2xi64>
246+
%c_offset = stablehlo.constant dense<1> : tensor<2x2xi64>
247+
%indices_2d = stablehlo.add %added, %c_offset : tensor<2x2xi64>
248+
// indices_2d is [[1, 2],
249+
// [3, 4]]
250+
%indices = stablehlo.reshape %indices_2d : (tensor<2x2xi64>) -> tensor<4x1xi64>
251+
252+
%0 = "stablehlo.gather"(%arg0, %indices) {
253+
dimension_numbers = #stablehlo.gather<
254+
collapsed_slice_dims = [0],
255+
start_index_map = [0],
256+
index_vector_dim = 1
257+
>,
258+
slice_sizes = array<i64: 1>
259+
} : (tensor<10xi64>, tensor<4x1xi64>) -> tensor<4xi64>
260+
return %0 : tensor<4xi64>
261+
}
262+
// CHECK-LABEL: func.func @gather_reshaped_iota_to_slice
263+
// CHECK-NEXT: %[[SLICE:.+]] = stablehlo.slice %arg0 [1:5]
264+
// CHECK-NEXT: return %[[SLICE]]
265+

0 commit comments

Comments (0)