@@ -43,15 +43,15 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
// -----
// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}>
- // CHECK-LABEL: store_dword
+ // CHECK-LABEL: store_dword_128x128
// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr<f16>, #mma> -> tensor<128x128x!tt.ptr<f16>, #linear>
// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear>
// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<128x128x!tt.ptr<f16>, #linear>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
- tt.func public @store_dword(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+ tt.func public @store_dword_128x128(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
@@ -63,3 +63,26 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32}
tt.return
}
}
+
+ // -----
+ // CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 128], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 32], [0, 64], [32, 0]], block = []}>
+ // CHECK-LABEL: store_dword_256x256
+ // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+ // CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256x!tt.ptr<f16>, #mma> -> tensor<256x256x!tt.ptr<f16>, #linear>
+ // CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256xf16, #mma> -> tensor<256x256xf16, #linear>
+ // CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<256x256x!tt.ptr<f16>, #linear>
+ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+ #mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
+ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
+ tt.func public @store_dword_256x256(%arg0: !tt.ptr<f16>) attributes {noinline = false} {
+ %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+ %cst_0 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+ %cst_1 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+ %0 = tt.dot %cst_0, %cst_1, %cst : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+ %1 = ttg.convert_layout %0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+ %2 = arith.truncf %1 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+ %3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x256x!tt.ptr<f16>, #blocked>
+ tt.store %3, %2 : tensor<256x256x!tt.ptr<f16>, #blocked>
+ tt.return
+ }
+ }