Skip to content

Commit c2dfd55

Browse files
committed
feat: transpose scatter to scatter transpose
1 parent 418f80e commit c2dfd55

File tree

6 files changed

+228
-30
lines changed

6 files changed

+228
-30
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include <cstddef>
5757
#include <iterator>
5858
#include <mlir/IR/Value.h>
59+
#include <mlir/IR/ValueRange.h>
5960
#include <numeric>
6061
#define DEBUG_TYPE "enzymehloopt"
6162

@@ -22032,15 +22033,12 @@ struct TransposeReverse final
2203222033
if (!reverseOp->getResult(0).hasOneUse())
2203322034
return failure();
2203422035

22035-
auto invPerm = getInversePermutation(op.getPermutation());
22036-
SmallVector<int64_t> newReverseDims(reverseOp.getDimensions().size());
22037-
for (auto [i, dim] : llvm::enumerate(reverseOp.getDimensions()))
22038-
newReverseDims[i] = invPerm[dim];
22039-
2204022036
auto newTranspose = stablehlo::TransposeOp::create(
2204122037
rewriter, op.getLoc(), reverseOp.getOperand(), op.getPermutation());
22042-
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(op, newTranspose,
22043-
newReverseDims);
22038+
rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(
22039+
op, newTranspose,
22040+
applyInversePermutationToDims(op.getPermutation(),
22041+
reverseOp.getDimensions()));
2204422042
return success();
2204522043
}
2204622044
};
@@ -28543,6 +28541,56 @@ struct FuseReshapeCollapseOrExpandDimsIntoReduce final
2854328541
}
2854428542
};
2854528543

28544+
struct TransposeScatter final
28545+
: CheckedOpRewritePattern<stablehlo::TransposeOp, TransposeScatter> {
28546+
using CheckedOpRewritePattern::CheckedOpRewritePattern;
28547+
28548+
LogicalResult matchAndRewriteImpl(stablehlo::TransposeOp op,
28549+
PatternRewriter &rewriter) {
28550+
auto scatterOp = op.getOperand().getDefiningOp<stablehlo::ScatterOp>();
28551+
if (!scatterOp) {
28552+
return rewriter.notifyMatchFailure(op,
28553+
"TransposeOp with non-scatter input");
28554+
}
28555+
28556+
if (!isOnlyUsedInOperation(scatterOp, op)) {
28557+
return failure();
28558+
}
28559+
28560+
SmallVector<Value> transposedInputs;
28561+
for (auto input : scatterOp.getInputs()) {
28562+
auto transposedInput = stablehlo::TransposeOp::create(
28563+
rewriter, op.getLoc(), input, op.getPermutation());
28564+
transposedInputs.push_back(transposedInput);
28565+
}
28566+
28567+
auto scatterDims = scatterOp.getScatterDimensionNumbers();
28568+
auto invPerm = getInversePermutation(op.getPermutation());
28569+
28570+
auto newInputBatchingDims =
28571+
applyPermutationToDims(invPerm, scatterDims.getInputBatchingDims());
28572+
llvm::sort(newInputBatchingDims);
28573+
28574+
auto newScatterDimsToOperandDims = applyPermutationToDims(
28575+
invPerm, scatterDims.getScatterDimsToOperandDims());
28576+
28577+
auto newScatterDims = stablehlo::ScatterDimensionNumbersAttr::get(
28578+
rewriter.getContext(), scatterDims.getUpdateWindowDims(),
28579+
scatterDims.getInsertedWindowDims(), newInputBatchingDims,
28580+
scatterDims.getScatterIndicesBatchingDims(),
28581+
newScatterDimsToOperandDims, scatterDims.getIndexVectorDim());
28582+
28583+
auto newScatterOp = stablehlo::ScatterOp::create(
28584+
rewriter, op.getLoc(), TypeRange(op.getType()), transposedInputs,
28585+
scatterOp.getScatterIndices(), scatterOp.getUpdates(), newScatterDims,
28586+
scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
28587+
newScatterOp.getUpdateComputation().takeBody(
28588+
scatterOp.getUpdateComputation());
28589+
rewriter.replaceOp(op, newScatterOp->getResults());
28590+
return success();
28591+
}
28592+
};
28593+
2854628594
/////////////// End Imported from stablehlo
2854728595

2854828596
// clang-format off
@@ -29065,13 +29113,14 @@ struct EnzymeHLOOptPass
2906529113
}
2906629114

2906729115
if (passses & (2048 * 32)) {
29068-
patterns.add<TransposeWhile, TransposeSliceBase<stablehlo::SliceOp>,
29069-
TransposeConcat, TransposeDUS, TransposeIota,
29070-
TransposeReduceWindow, TransposeReduce, TransposeSelect,
29071-
TransposeSliceBase<stablehlo::DynamicSliceOp>,
29072-
TransposeReverse, TransposeBatchNormTraining,
29073-
TransposeBatchNormInference, TransposeBatchNormGrad,
29074-
TransposeIf, TransposeFFT, TransposeReshape>(context);
29116+
patterns
29117+
.add<TransposeWhile, TransposeSliceBase<stablehlo::SliceOp>,
29118+
TransposeConcat, TransposeDUS, TransposeIota,
29119+
TransposeReduceWindow, TransposeReduce, TransposeSelect,
29120+
TransposeSliceBase<stablehlo::DynamicSliceOp>, TransposeReverse,
29121+
TransposeBatchNormTraining, TransposeBatchNormInference,
29122+
TransposeBatchNormGrad, TransposeIf, TransposeFFT,
29123+
TransposeReshape, TransposeScatter>(context);
2907529124
patterns.add<TransposeElementwise>(true, context);
2907629125
}
2907729126

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2694,3 +2694,8 @@ def ApplyWhileElementwiseReductionToReducePatterns : EnzymeHLOPatternOp<
26942694
"while_elementwise_reduction_to_reduce"> {
26952695
let patterns = ["WhileElementwiseReductionToReduce"];
26962696
}
2697+
2698+
def ApplyTransposeScatterPatterns : EnzymeHLOPatternOp<
2699+
"transpose_scatter"> {
2700+
let patterns = ["TransposeScatter"];
2701+
}

src/enzyme_ad/jax/Utils.cpp

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2384,6 +2384,35 @@ SmallVector<int64_t> getInversePermutation(ArrayRef<int64_t> perm) {
23842384
return res;
23852385
}
23862386

2387+
SmallVector<int64_t> applyPermutationToDims(ArrayRef<int64_t> perm,
2388+
ArrayRef<int64_t> dims) {
2389+
SmallVector<int64_t> res(dims.size());
2390+
for (auto en : llvm::enumerate(dims)) {
2391+
res[en.index()] = perm[en.value()];
2392+
}
2393+
return res;
2394+
}
2395+
2396+
template <typename T>
2397+
SmallVector<T> applyPermutation(ArrayRef<int64_t> perm, ArrayRef<T> values) {
2398+
SmallVector<T> res;
2399+
for (auto p : perm) {
2400+
res.push_back(values[p]);
2401+
}
2402+
return res;
2403+
}
2404+
2405+
SmallVector<int64_t> applyInversePermutationToDims(ArrayRef<int64_t> perm,
2406+
ArrayRef<int64_t> dims) {
2407+
return applyPermutationToDims(getInversePermutation(perm), dims);
2408+
}
2409+
2410+
template <typename T>
2411+
SmallVector<T> applyInversePermutation(ArrayRef<int64_t> perm,
2412+
ArrayRef<T> values) {
2413+
return applyPermutation(getInversePermutation(perm), values);
2414+
}
2415+
23872416
Value transposeSliceHelper(stablehlo::TransposeOp transpose,
23882417
PatternRewriter &rewriter, stablehlo::SliceOp op) {
23892418
return transposeSliceHelper(transpose, rewriter, op.getStartIndices(),
@@ -2449,10 +2478,9 @@ Value sliceTransposeHelper(stablehlo::TransposeOp transpose,
24492478
auto newUpdate =
24502479
TransposeOpCreate(rewriter, transpose->getLoc(), op.getUpdate(),
24512480
transpose.getPermutation());
2452-
SmallVector<Value> starts;
2453-
for (auto ind : getInversePermutation(transpose.getPermutation())) {
2454-
starts.push_back(op.getStartIndices()[ind]);
2455-
}
2481+
SmallVector<Value> startIndices = llvm::to_vector(op.getStartIndices());
2482+
auto starts = applyInversePermutation(transpose.getPermutation(),
2483+
ArrayRef<Value>(startIndices));
24562484
return stablehlo::DynamicUpdateSliceOp::create(
24572485
rewriter, transpose->getLoc(), transpose.getOperand(), newUpdate, starts);
24582486
}
@@ -2461,12 +2489,10 @@ Value sliceTransposeHelper(stablehlo::TransposeOp transpose,
24612489
PatternRewriter &rewriter, ArrayRef<int64_t> starts,
24622490
ArrayRef<int64_t> limits,
24632491
ArrayRef<int64_t> strides) {
2464-
SmallVector<int64_t> start, end, step;
2465-
for (auto ind : getInversePermutation(transpose.getPermutation())) {
2466-
start.push_back(starts[ind]);
2467-
end.push_back(limits[ind]);
2468-
step.push_back(strides[ind]);
2469-
}
2492+
auto invPerm = getInversePermutation(transpose.getPermutation());
2493+
auto start = applyPermutation(invPerm, starts);
2494+
auto end = applyPermutation(invPerm, limits);
2495+
auto step = applyPermutation(invPerm, strides);
24702496
return SliceOpCreate(rewriter, transpose.getLoc(), transpose.getOperand(),
24712497
start, end, step);
24722498
}
@@ -2475,12 +2501,9 @@ Value sliceTransposeHelper(stablehlo::TransposeOp transpose,
24752501
PatternRewriter &rewriter,
24762502
ArrayRef<Value> sliceStarts,
24772503
ArrayRef<int64_t> sliceSizes) {
2478-
SmallVector<int64_t> sizes;
2479-
SmallVector<Value> starts;
2480-
for (auto ind : getInversePermutation(transpose.getPermutation())) {
2481-
sizes.push_back(sliceSizes[ind]);
2482-
starts.push_back(sliceStarts[ind]);
2483-
}
2504+
auto invPerm = getInversePermutation(transpose.getPermutation());
2505+
auto sizes = applyPermutation(invPerm, sliceSizes);
2506+
auto starts = applyPermutation(invPerm, sliceStarts);
24842507
return DynamicSliceOpCreate(rewriter, transpose.getLoc(),
24852508
transpose.getOperand(), starts, sizes);
24862509
}
@@ -2549,6 +2572,15 @@ bool isFusible(Operation *op, stablehlo::BroadcastInDimOp bcast) {
25492572
.Default([](auto other) { return matchPattern(other, m_Constant()); });
25502573
}
25512574

2575+
bool isFusible(Operation *op, stablehlo::TransposeOp transpose) {
2576+
return TypeSwitch<Operation *, bool>(op)
2577+
.Case<stablehlo::TransposeOp, stablehlo::BroadcastInDimOp>(
2578+
[](auto prevOp) { return true; })
2579+
.Case<stablehlo::ReshapeOp>(
2580+
[](auto reshape) { return reshapeIsTranspose(reshape); })
2581+
.Default([](auto other) { return matchPattern(other, m_Constant()); });
2582+
}
2583+
25522584
bool IsTensorFilled(Value input) {
25532585
// Use a worklist-based approach to traverse the SSA def-use chain
25542586
// and determine if the value is known to be a dense (fully-populated) matrix.

src/enzyme_ad/jax/Utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,20 @@ bool canFuseIntoReduce(Operation *op);
12211221

12221222
llvm::SmallVector<int64_t> getInversePermutation(ArrayRef<int64_t> perm);
12231223

1224+
llvm::SmallVector<int64_t> applyPermutationToDims(ArrayRef<int64_t> perm,
1225+
ArrayRef<int64_t> dims);
1226+
1227+
llvm::SmallVector<int64_t>
1228+
applyInversePermutationToDims(ArrayRef<int64_t> perm, ArrayRef<int64_t> dims);
1229+
1230+
template <typename T>
1231+
llvm::SmallVector<T> applyPermutation(ArrayRef<int64_t> perm,
1232+
ArrayRef<T> values);
1233+
1234+
template <typename T>
1235+
llvm::SmallVector<T> applyInversePermutation(ArrayRef<int64_t> perm,
1236+
ArrayRef<T> values);
1237+
12241238
Value transposeSliceHelper(stablehlo::TransposeOp transpose,
12251239
PatternRewriter &rewriter, stablehlo::SliceOp op);
12261240
Value transposeSliceHelper(stablehlo::TransposeOp transpose,
@@ -1258,6 +1272,7 @@ Value sliceTransposeHelper(stablehlo::TransposeOp transpose,
12581272
bool isFusible(stablehlo::TransposeOp transpose, Operation *op);
12591273
bool isFusible(Operation *op, stablehlo::BroadcastInDimOp bcast);
12601274
bool isFusible(Operation *op, stablehlo::ReshapeOp reshape);
1275+
bool isFusible(Operation *op, stablehlo::TransposeOp transpose);
12611276

12621277
template <typename OpTy>
12631278
Value getIdentityValueForOp(OpBuilder &builder, Location loc, Type elemType);

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,7 @@ def optimization_passes(
570570
"transpose_if",
571571
"transpose_fft",
572572
"transpose_reshape",
573+
"transpose_scatter",
573574
]
574575
elif transpose_propagate == "down":
575576
transform_passes_list += [
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// RUN: enzymexlamlir-opt %s --pass-pipeline='builtin.module(enzyme-hlo-opt{passses=65536},enzyme-hlo-opt)' | FileCheck %s
2+
3+
func.func @main1(%arg0: tensor<5x2xf32>, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
4+
%c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
5+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32>
6+
%1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
7+
%2 = "stablehlo.scatter"(%0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>}> ({
8+
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
9+
%4 = stablehlo.multiply %arg2, %arg3 : tensor<f32>
10+
stablehlo.return %4 : tensor<f32>
11+
}) : (tensor<2x5xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<2x5xf32>
12+
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32>
13+
return %3 : tensor<5x2xf32>
14+
}
15+
16+
// CHECK: func.func @main1(%arg0: tensor<5x2xf32>, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
17+
// CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
18+
// CHECK-NEXT: %0 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
19+
// CHECK-NEXT: %1 = "stablehlo.scatter"(%arg0, %c, %0) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [0]>}> ({
20+
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
21+
// CHECK-NEXT: %2 = stablehlo.multiply %arg2, %arg3 : tensor<f32>
22+
// CHECK-NEXT: stablehlo.return %2 : tensor<f32>
23+
// CHECK-NEXT: }) : (tensor<5x2xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<5x2xf32>
24+
// CHECK-NEXT: return %1 : tensor<5x2xf32>
25+
// CHECK-NEXT: }
26+
27+
func.func @main2(%arg0: tensor<5x2xf32>, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
28+
%c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
29+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x2xf32>) -> tensor<2x5xf32>
30+
%1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
31+
%2 = "stablehlo.scatter"(%0, %c, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>}> ({
32+
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
33+
%4 = stablehlo.add %arg2, %arg3 : tensor<f32>
34+
stablehlo.return %4 : tensor<f32>
35+
}) : (tensor<2x5xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<2x5xf32>
36+
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x5xf32>) -> tensor<5x2xf32>
37+
return %3 : tensor<5x2xf32>
38+
}
39+
40+
// CHECK: func.func @main2(%arg0: tensor<5x2xf32>, %arg1: tensor<4x3x2xf32>) -> tensor<5x2xf32> {
41+
// CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[[0, 1, 2, 3], [3, 1, 0, 2], [2, 4, 4, 2]]]> : tensor<1x3x4xi64>
42+
// CHECK-NEXT: %0 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<4x3x2xf32>) -> tensor<2x3x4xf32>
43+
// CHECK-NEXT: %1 = "stablehlo.scatter"(%arg0, %c, %0) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [0]>}> ({
44+
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
45+
// CHECK-NEXT: %2 = stablehlo.add %arg2, %arg3 : tensor<f32>
46+
// CHECK-NEXT: stablehlo.return %2 : tensor<f32>
47+
// CHECK-NEXT: }) : (tensor<5x2xf32>, tensor<1x3x4xi64>, tensor<2x3x4xf32>) -> tensor<5x2xf32>
48+
// CHECK-NEXT: return %1 : tensor<5x2xf32>
49+
// CHECK-NEXT: }
50+
51+
func.func @main3(%arg0: tensor<32x32xf32>, %arg1: tensor<32xf32>) -> tensor<32x32xf32> {
52+
%c = stablehlo.constant dense<[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11], [12, 12], [13, 13], [14, 14], [15, 15], [16, 16], [17, 17], [18, 18], [19, 19], [20, 20], [21, 21], [22, 22], [23, 23], [24, 24], [25, 25], [26, 26], [27, 27], [28, 28], [29, 29], [30, 30], [31, 31]]> : tensor<32x2xi64>
53+
%cst = stablehlo.constant dense<0.000000e+00> : tensor<32x32xf32>
54+
%0 = "stablehlo.scatter"(%cst, %c, %arg1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
55+
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
56+
stablehlo.return %arg3 : tensor<f32>
57+
}) {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<32x32xf32>, tensor<32x2xi64>, tensor<32xf32>) -> tensor<32x32xf32>
58+
%1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<32x32xf32>) -> tensor<32x32xf32>
59+
%2 = stablehlo.add %arg0, %1 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<32x32xf32>
60+
return %2 : tensor<32x32xf32>
61+
}
62+
63+
// CHECK: func.func @main3(%arg0: tensor<32x32xf32>, %arg1: tensor<32xf32>) -> tensor<32x32xf32> {
64+
// CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9], [10, 10], [11, 11], [12, 12], [13, 13], [14, 14], [15, 15], [16, 16], [17, 17], [18, 18], [19, 19], [20, 20], [21, 21], [22, 22], [23, 23], [24, 24], [25, 25], [26, 26], [27, 27], [28, 28], [29, 29], [30, 30], [31, 31]]> : tensor<32x2xi64>
65+
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<32x32xf32>
66+
// CHECK-NEXT: %0 = "stablehlo.scatter"(%cst, %c, %arg1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [1, 0], index_vector_dim = 1>, unique_indices = true}> ({
67+
// CHECK-NEXT: ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
68+
// CHECK-NEXT: stablehlo.return %arg3 : tensor<f32>
69+
// CHECK-NEXT: }) : (tensor<32x32xf32>, tensor<32x2xi64>, tensor<32xf32>) -> tensor<32x32xf32>
70+
// CHECK-NEXT: %1 = stablehlo.add %arg0, %0 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<32x32xf32>
71+
// CHECK-NEXT: return %1 : tensor<32x32xf32>
72+
// CHECK-NEXT: }
73+
74+
func.func @main4(%arg0: tensor<3x4x4xf64>) -> tensor<3x4x4xf64> {
75+
%cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
76+
%cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<3x4x3xf64>
77+
%c = stablehlo.constant dense<[[0], [2], [1]]> : tensor<3x1xi64>
78+
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x4x4xf64>) -> tensor<4x4x3xf64>
79+
%1 = "stablehlo.scatter"(%0, %c, %cst_0) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1, 2], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
80+
^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
81+
stablehlo.return %cst : tensor<f64>
82+
}) : (tensor<4x4x3xf64>, tensor<3x1xi64>, tensor<3x4x3xf64>) -> tensor<4x4x3xf64>
83+
%2 = stablehlo.transpose %1, dims = [2, 1, 0] : (tensor<4x4x3xf64>) -> tensor<3x4x4xf64>
84+
return %2 : tensor<3x4x4xf64>
85+
}
86+
87+
// CHECK: func.func @main4(%arg0: tensor<3x4x4xf64>) -> tensor<3x4x4xf64> {
88+
// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f64>
89+
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<2.000000e+00> : tensor<3x4x3xf64>
90+
// CHECK-NEXT{LITERAL}: %c = stablehlo.constant dense<[[0], [2], [1]]> : tensor<3x1xi64>
91+
// CHECK-NEXT: %0 = "stablehlo.scatter"(%arg0, %c, %cst_0) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1, 2], inserted_window_dims = [0], scatter_dims_to_operand_dims = [2], index_vector_dim = 1>, unique_indices = true}> ({
92+
// CHECK-NEXT: ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
93+
// CHECK-NEXT: stablehlo.return %cst : tensor<f64>
94+
// CHECK-NEXT: }) : (tensor<3x4x4xf64>, tensor<3x1xi64>, tensor<3x4x3xf64>) -> tensor<3x4x4xf64>
95+
// CHECK-NEXT: return %0 : tensor<3x4x4xf64>
96+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)