@@ -44,17 +44,19 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
   }
 }

-// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
-// // -----
+// -----

+// CHECK-DAG: [[$BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
+// CHECK-DAG: [[$BLOCKED2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 2, 4], order = [1, 2, 0]}>
+// CHECK-DAG: [[$MMA:#.*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
 // Check that optimization works with 3d tensors
 // in case of relatively small scratch buffer
-// DISABLE-CHECK-LABEL: alloc_convert_3d_load
-// DISABLE-CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
-// DISABLE-CHECK: %0 = ttg.local_alloc %arg0 : {{.*}}#blocked {{.*}}#shared
-// DISABLE-CHECK: %1 = ttg.convert_layout %arg1 : {{.*}}#blocked {{.*}}#mma
-// DISABLE-CHECK: %2 = ttg.convert_layout %1 : {{.*}}#mma {{.*}}#mma1
-// DISABLE-CHECK: %3 = ttg.local_load %0 : {{.*}}#shared{{.*}}#ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+// CHECK-LABEL: alloc_convert_3d_load
+// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
+// CHECK: [[V0:%.*]] = ttg.local_alloc {{.*}}[[$BLOCKED1]] {{.*}}
+// CHECK: [[V1:%.*]] = ttg.convert_layout {{.*}}[[$BLOCKED1]] {{.*}}[[$BLOCKED2]]
+// CHECK: [[V2:%.*]] = ttg.convert_layout [[V1]] : {{.*}}[[$BLOCKED2]] {{.*}}[[$MMA]]
+// CHECK: [[V3:%.*]] = ttg.local_load [[V0]] : {{.*}}#ttg.dot_op<{opIdx = 0, parent = [[$MMA]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
 #mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2]}>
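The checks above encode the rewrite the optimization under test is expected to perform on the 3d case: one blocked-to-mfma ttg.convert_layout is split into two steps through an intermediate blocked layout so each step needs a smaller LDS scratch buffer. The following is a minimal standalone sketch of that rewritten form, reusing the layouts quoted in the CHECK-DAG lines; the tensor shape, element type, function name, and the #blocked_tmp alias are illustrative assumptions, not taken from the test.

#blocked = #ttg.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
#blocked_tmp = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 8, 8], warpsPerCTA = [1, 2, 4], order = [1, 2, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: shape and element type are chosen for illustration only.
  tt.func public @convert_3d_sketch(%arg0: tensor<1x128x128xf16, #blocked>) {
    // Rewritten form: blocked -> intermediate blocked -> mfma, instead of a
    // single blocked -> mfma conversion with a large scratch buffer.
    %0 = ttg.convert_layout %arg0 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #blocked_tmp>
    %1 = ttg.convert_layout %0 : tensor<1x128x128xf16, #blocked_tmp> -> tensor<1x128x128xf16, #mma>
    tt.return
  }
}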
@@ -93,22 +95,21 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
   }
 }

-// FIXME: This was broken in https://github.com/triton-lang/triton/pull/5840
 // -----

 // Check that optimization correctly handles LDS shortcut (see #mma2 -> #dotop2 conversion)
-// DISABLE-CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-// DISABLE-CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
-// DISABLE-CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-// DISABLE-CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
-// DISABLE-CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
+// CHECK-DAG: [[BLOCKED_1:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+// CHECK-DAG: [[BLOCKED_2:#[a-z0-9]*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
+// CHECK-DAG: [[MMA_1:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+// CHECK-DAG: [[MMA_2:#[a-z0-9]*]] = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
+// CHECK-DAG: [[SHARED:#[a-z0-9]*]] = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>

-// DISABLE-CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
-// DISABLE-CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
-// DISABLE-CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
-// DISABLE-CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
-// DISABLE-CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
-// DISABLE-CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
+// CHECK: tt.func public @mfma_dot_shortcut([[ARG_0:%[a-z0-9]*]]: {{.*}}, [[ARG_1:%[a-z0-9]*]]: {{.*}}, [[ARG_2:%[a-z0-9]*]]: {{.*}})
+// CHECK: [[ALLOC:%[0-9]+]] = ttg.local_alloc [[ARG_0]] : (tensor<128x128xf16, [[BLOCKED_1]]>) -> !ttg.memdesc<128x128xf16, [[SHARED]], #smem>
+// CHECK: [[INTERMEDIATE_CONV:%[0-9]+]] = ttg.convert_layout [[ARG_1]] : tensor<128x128xf32, [[BLOCKED_1]]> -> tensor<128x128xf32, [[BLOCKED_2]]>
+// CHECK: [[CONVERT_1:%[0-9]+]] = ttg.convert_layout [[INTERMEDIATE_CONV]] : tensor<128x128xf32, [[BLOCKED_2]]> -> tensor<128x128xf32, [[MMA_2]]>
+// CHECK: [[CONVERT_2:%[0-9]+]] = ttg.convert_layout [[ARG_2]] : tensor<256x128xf16, [[MMA_1]]> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_1]], kWidth = 4}>>
+// CHECK: [[LOAD:%[0-9]+]] = ttg.local_load [[ALLOC]] : !ttg.memdesc<128x128xf16, [[SHARED]], #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[MMA_2]], kWidth = 4}>>
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
 #mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
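To make the shortcut behavior these checks describe concrete, here is a hedged sketch of the IR shape they expect after the pass, written out with the layout aliases defined in this test: the blocked-to-mfma conversion of %arg1 is split through an intermediate blocked layout, while the mfma-to-dot_op conversion of %arg2 can use the in-register shortcut and stays a single convert_layout. Shapes, element types, and layouts are taken from the CHECK lines; the function name, the #blocked_tmp alias, and the #smem = #ttg.shared_memory alias are assumptions for the sketch.

#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked_tmp = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [2, 4], order = [0, 1]}>
#mma1 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#mma2 = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
  // NOTE: function name is illustrative; operand shapes/types mirror the CHECK lines above.
  tt.func public @mfma_dot_shortcut_sketch(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>, %arg2: tensor<256x128xf16, #mma2>) {
    // The allocation itself is untouched by the pass.
    %alloc = ttg.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem>
    // blocked -> mfma is split in two to shrink the LDS scratch buffer.
    %tmp = ttg.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked_tmp>
    %cvt = ttg.convert_layout %tmp : tensor<128x128xf32, #blocked_tmp> -> tensor<128x128xf32, #mma1>
    // mfma -> dot_op goes through the shortcut, so it is left as one conversion.
    %dot = ttg.convert_layout %arg2 : tensor<256x128xf16, #mma2> -> tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 4}>>
    %load = ttg.local_load %alloc : !ttg.memdesc<128x128xf16, #shared, #smem> -> tensor<128x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
    tt.return
  }
}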
@@ -125,6 +126,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32}
     tt.return
   }
 }
+
 // -----

 // Checks that optimization do not crash on 1d tensor