@@ -36,9 +36,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
+  tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf16, #blocked>) attributes {noinline = false} {
     %1 = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
-    %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
+    %2 = ttg.convert_layout %arg1 : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #mma>
     %3 = ttg.local_load %1 : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
@@ -62,9 +62,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} {
+  tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x256x128xf16, #blocked>) attributes {noinline = false} {
     %1 = ttg.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !ttg.memdesc<1x128x128xf16, #shared, #smem>
-    %2 = ttg.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma>
+    %2 = ttg.convert_layout %arg1 : tensor<1x256x128xf16, #blocked> -> tensor<1x256x128xf16, #mma>
     %3 = ttg.local_load %1 : !ttg.memdesc<1x128x128xf16, #shared, #smem> -> tensor<1x128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
     tt.return
   }
@@ -87,9 +87,9 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} {
+  tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
     %1 = ttg.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !ttg.memdesc<64x128xf16, #shared, #smem>
-    %2 = ttg.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma>
+    %2 = ttg.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
     %3 = ttg.local_load %1 : !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
     tt.return
   }
@@ -98,29 +98,29 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
 // -----

 // Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
-// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 // CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
 // CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
 // CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 // CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>

 // CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
 // CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
-// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
-// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
+// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] {{.*}}: tensor<256x128xf32, [[BLOCKED_1]]> -> tensor<256x128xf32, [[BLOCKED_2]]>
+// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] {{.*}}: tensor<256x128xf32, [[BLOCKED_2]]> -> tensor<256x128xf32, [[MMA_2]]>
 // CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] {{.*}}: tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
 // CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
-#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma1 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #mma2 = #ttg.amd_mfma<{version = 2, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
 #dotop1 = #ttg.dot_op<{opIdx=0, parent=#mma1, kWidth=4}>
 #dotop2 = #ttg.dot_op<{opIdx=0, parent=#mma2, kWidth=4}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} {
+  tt.func public @mfma_dot_shortcut(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<256x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) attributes {noinline = false} {
     %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
-    %convert_1 = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma1>
+    %convert_1 = ttg.convert_layout %arg1 : tensor<256x128xf32, #blocked> -> tensor<256x128xf32, #mma1>
     %convert_2 = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #dotop2>
     %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #dotop1>
     tt.return