File tree Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Expand file tree Collapse file tree 2 files changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -199,6 +199,11 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
199199 return getResult ();
200200 }
201201
202+ // Eliminate splat constant transpose ops.
203+ if (auto attr =
204+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSrc ()))
205+ return attr.reshape (getType ());
206+
202207 return {};
203208}
204209
Original file line number Diff line number Diff line change @@ -173,3 +173,14 @@ tt.func @fold_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
173173 // CHECK-NEXT: tt.return %[[cst]] : tensor<8x2xf32>
174174 tt.return %bst_out : tensor <8 x2 xf32 >
175175}
176+
177+ // -----
178+
179+ // CHECK-LABEL: @fold_transpose_constant
180+ tt.func @fold_transpose_constant () -> tensor <128 x16 xf32 > {
181+ // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<128x16xf32>
182+ %cst = arith.constant dense <1.0 > : tensor <16 x128 xf32 >
183+ %r = tt.trans %cst {order = array<i32 : 1 , 0 >} : tensor <16 x128 xf32 > -> tensor <128 x16 xf32 >
184+ // CHECK-NEXT: tt.return %[[cst]] : tensor<128x16xf32>
185+ tt.return %r : tensor <128 x16 xf32 >
186+ }
You can’t perform that action at this time.
0 commit comments