Skip to content

Commit deee78f

Browse files
[BACKEND] Fold transpose(splat_const) (#5259)
Add folding for a transpose of a splat constant. --------- Co-authored-by: peterbell10 <[email protected]>
1 parent 22e212b commit deee78f

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
199199
return getResult();
200200
}
201201

202+
// Eliminate splat constant transpose ops.
203+
if (auto attr =
204+
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSrc()))
205+
return attr.reshape(getType());
206+
202207
return {};
203208
}
204209

test/Triton/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,14 @@ tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
173173
// CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
174174
tt.return %bst_out : tensor<8x2xf32>
175175
}
176+
177+
// -----
178+
179+
// CHECK-LABEL: @fold_transpose_constant
180+
tt.func @fold_transpose_constant() -> tensor<128x16xf32> {
181+
// CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<128x16xf32>
182+
%cst = arith.constant dense<1.0> : tensor<16x128xf32>
183+
%r = tt.trans %cst {order = array<i32: 1, 0>} : tensor<16x128xf32> -> tensor<128x16xf32>
184+
// CHECK-NEXT: tt.return %[[cst]] : tensor<128x16xf32>
185+
tt.return %r : tensor<128x16xf32>
186+
}

0 commit comments

Comments
 (0)