@@ -43,15 +43,15 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 
 // -----
 // CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}>
-// CHECK-LABEL: store_dword
+// CHECK-LABEL: store_dword_128x128
 // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
 // CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr<f16>, #mma> -> tensor<128x128x!tt.ptr<f16>, #linear>
 // CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear>
 // CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<128x128x!tt.ptr<f16>, #linear>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
 #mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @store_dword(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+  tt.func public @store_dword_128x128(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
     %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
@@ -63,3 +63,26 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
+// -----
+// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 128], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 32], [0, 64], [32, 0]], block = []}>
+// CHECK-LABEL: store_dword_256x256
+// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256x!tt.ptr<f16>, #mma> -> tensor<256x256x!tt.ptr<f16>, #linear>
+// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256xf16, #mma> -> tensor<256x256xf16, #linear>
+// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<256x256x!tt.ptr<f16>, #linear>
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @store_dword_256x256(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    %cst_0 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    %cst_1 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %0 = tt.dot %cst_0, %cst_1, %cst : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+    %1 = ttg.convert_layout %0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+    %2 = arith.truncf %1 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked>
+    tt.store %3, %2 : tensor<256x256x!tt.ptr<f16>, #blocked>
+    tt.return
+  }
+}