@@ -36,9 +36,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
36
36
#shared = #ttg.swizzled_shared <{vec = 4 , perPhase = 1 , maxPhase = 16 , order = [0 , 1 ]}>
37
37
#smem = #ttg.shared_memory
38
38
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
39
- tt.func public @alloc_convert_small_load (%arg0: tensor <128 x128 xf16 , #blocked >, %arg1: tensor <128 x 128 x f16 , #blocked >) attributes {noinline = false } {
39
+ tt.func public @alloc_convert_small_load (%arg0: tensor <128 x128 xf16 , #blocked >, %arg1: tensor <256 x 128 x f16 , #blocked >) attributes {noinline = false } {
40
40
%1 = ttg.local_alloc %arg0 : (tensor <128 x128 xf16 , #blocked >) -> !ttg.memdesc <128 x128 xf16 , #shared , #smem >
41
- %2 = ttg.convert_layout %arg1 : tensor <128 x 128 x f16 , #blocked > -> tensor <128 x 128 x f16 , #mma >
41
+ %2 = ttg.convert_layout %arg1 : tensor <256 x 128 x f16 , #blocked > -> tensor <256 x 128 x f16 , #mma >
42
42
%3 = ttg.local_load %1 : !ttg.memdesc <128 x128 xf16 , #shared , #smem > -> tensor <128 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>>
43
43
tt.return
44
44
}
@@ -62,9 +62,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
62
62
#shared = #ttg.swizzled_shared <{vec = 4 , perPhase = 1 , maxPhase = 16 , order = [0 , 1 , 2 ]}>
63
63
#smem = #ttg.shared_memory
64
64
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
65
- tt.func public @alloc_convert_3d_load (%arg0: tensor <1 x128 x128 xf16 , #blocked >, %arg1: tensor <1 x 128 x 128 x f16 , #blocked >) attributes {noinline = false } {
65
+ tt.func public @alloc_convert_3d_load (%arg0: tensor <1 x128 x128 xf16 , #blocked >, %arg1: tensor <1 x 256 x 128 x f16 , #blocked >) attributes {noinline = false } {
66
66
%1 = ttg.local_alloc %arg0 : (tensor <1 x128 x128 xf16 , #blocked >) -> !ttg.memdesc <1 x128 x128 xf16 , #shared , #smem >
67
- %2 = ttg.convert_layout %arg1 : tensor <1 x 128 x 128 x f16 , #blocked > -> tensor <1 x 128 x 128 x f16 , #mma >
67
+ %2 = ttg.convert_layout %arg1 : tensor <1 x 256 x 128 x f16 , #blocked > -> tensor <1 x 256 x 128 x f16 , #mma >
68
68
%3 = ttg.local_load %1 : !ttg.memdesc <1 x128 x128 xf16 , #shared , #smem > -> tensor <1 x128 x128 xf16 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 4 }>>
69
69
tt.return
70
70
}
@@ -87,9 +87,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
87
87
#shared = #ttg.swizzled_shared <{vec = 4 , perPhase = 1 , maxPhase = 16 , order = [0 , 1 ]}>
88
88
#smem = #ttg.shared_memory
89
89
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
90
- tt.func public @alloc_convert_32k_limit (%arg0: tensor <64 x128 xf16 , #blocked >, %arg1: tensor <64 x 128 x f16 , #blocked >) attributes {noinline = false } {
90
+ tt.func public @alloc_convert_32k_limit (%arg0: tensor <64 x128 xf16 , #blocked >, %arg1: tensor <128 x 128 x f16 , #blocked >) attributes {noinline = false } {
91
91
%1 = ttg.local_alloc %arg0 : (tensor <64 x128 xf16 , #blocked >) -> !ttg.memdesc <64 x128 xf16 , #shared , #smem >
92
- %2 = ttg.convert_layout %arg1 : tensor <64 x 128 x f16 , #blocked > -> tensor <64 x 128 x f16 , #mma >
92
+ %2 = ttg.convert_layout %arg1 : tensor <128 x 128 x f16 , #blocked > -> tensor <128 x 128 x f16 , #mma >
93
93
%3 = ttg.local_load %1 : !ttg.memdesc <64 x128 xf16 , #shared , #smem > -> tensor <64 x128 xf16 , #ttg.dot_op <{opIdx = 0 , kWidth = 4 , parent = #mma }>>
94
94
tt.return
95
95
}
@@ -98,29 +98,29 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
98
98
// -----
99
99
100
100
// Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
101
- // CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
101
+ // CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
102
102
// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
103
103
// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
104
104
// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
105
105
// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
106
106
107
107
// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
108
108
// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
109
- // CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
110
- // CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
109
+ // CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<256x128xf32, [[BLOCKED_1]]> -> tensor<256x128xf32, [[BLOCKED_2]]>
110
+ // CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<256x128xf32, [[BLOCKED_2]]> -> tensor<256x128xf32, [[MMA_2]]>
111
111
// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] {{.*}}: tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
112
112
// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
113
- #blocked = #ttg.blocked <{sizePerThread = [4 , 1 ], threadsPerWarp = [16 , 4 ], warpsPerCTA = [1 , 8 ], order = [0 , 1 ]}>
113
+ #blocked = #ttg.blocked <{sizePerThread = [8 , 1 ], threadsPerWarp = [16 , 4 ], warpsPerCTA = [1 , 8 ], order = [0 , 1 ]}>
114
114
#mma1 = #ttg.amd_mfma <{version = 2 , warpsPerCTA = [1 , 8 ], instrShape = [32 , 32 ], isTransposed = false }>
115
115
#mma2 = #ttg.amd_mfma <{version = 2 , warpsPerCTA = [8 , 1 ], instrShape = [32 , 32 ], isTransposed = true }>
116
116
#dotop1 = #ttg.dot_op <{opIdx =0 , parent =#mma1 , kWidth =4 }>
117
117
#dotop2 = #ttg.dot_op <{opIdx =0 , parent =#mma2 , kWidth =4 }>
118
118
#shared = #ttg.swizzled_shared <{vec = 4 , perPhase = 1 , maxPhase = 16 , order = [0 , 1 ]}>
119
119
#smem = #ttg.shared_memory
120
120
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
121
- tt.func public @mfma_dot_shortcut (%arg0: tensor <128 x128 xf16 , #blocked >, %arg1: tensor <128 x 128 x f32 , #blocked >, %arg2: tensor <256 x128 xf16 , #mma2 >) attributes {noinline = false } {
121
+ tt.func public @mfma_dot_shortcut (%arg0: tensor <128 x128 xf16 , #blocked >, %arg1: tensor <256 x 128 x f32 , #blocked >, %arg2: tensor <256 x128 xf16 , #mma2 >) attributes {noinline = false } {
122
122
%alloc = ttg.local_alloc %arg0 : (tensor <128 x128 xf16 , #blocked >) -> !ttg.memdesc <128 x128 xf16 , #shared , #smem >
123
- %convert_1 = ttg.convert_layout %arg1 : tensor <128 x 128 x f32 , #blocked > -> tensor <128 x 128 x f32 , #mma1 >
123
+ %convert_1 = ttg.convert_layout %arg1 : tensor <256 x 128 x f32 , #blocked > -> tensor <256 x 128 x f32 , #mma1 >
124
124
%convert_2 = ttg.convert_layout %arg2 : tensor <256 x128 xf16 , #mma2 > -> tensor <256 x128 xf16 , #dotop2 >
125
125
%load = ttg.local_load %alloc : !ttg.memdesc <128 x128 xf16 , #shared , #smem > -> tensor <128 x128 xf16 , #dotop1 >
126
126
tt.return
0 commit comments