@@ -125,3 +125,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %3 : tensor<128x32xf32, #blocked>
   }
 }
+
+// -----
+
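+// Refine an mfma -> dot_op(mfma) conversion: the convert_layout below is expected to be split into per-K-slice amdgpu.extract_slice/ttg.convert_layout pairs and reassembled with amdgpu.concat.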
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @convert_layout(%arg0: tensor<128x64xf16, #mma>) attributes {noinline = false} {
+    // CHECK-LABEL: convert_layout
+
+    // CHECK: [[ES_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
+    // CHECK: [[CL_0:%.*]] = ttg.convert_layout [[ES_0]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    // CHECK: [[ES_1:%.*]] = amdgpu.extract_slice %arg0 [0, 16] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
+    // CHECK: [[CL_1:%.*]] = ttg.convert_layout [[ES_1]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    // CHECK: [[ES_2:%.*]] = amdgpu.extract_slice %arg0 [0, 32] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
+    // CHECK: [[CL_2:%.*]] = ttg.convert_layout [[ES_2]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    // CHECK: [[ES_3:%.*]] = amdgpu.extract_slice %arg0 [0, 48] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma>
+    // CHECK: [[CL_3:%.*]] = ttg.convert_layout [[ES_3]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    // CHECK: %8 = amdgpu.concat [[CL_0]], [[CL_1]], [[CL_2]], [[CL_3]] [1, 4] {loweringOrder = array<i64: 1, 0>} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+
+    %0 = ttg.convert_layout %arg0 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
+    amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
+    tt.return
+  }
+}
+
+// -----
+
+// The blocked layout CTA tile covers the whole tensor, so no transformation should happen.
+// CHECK-LABEL: @convert_layout_kernel_neg
+// CHECK-NOT: amdgpu.extract_slice
+#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @convert_layout_kernel_neg(%arg0: tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked2> attributes {noinline = false} {
+    amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
+    %0 = ttg.convert_layout %arg0 : tensor<128x32xf32, #blocked1> -> tensor<128x32xf32, #blocked2>
+    tt.return %0 : tensor<128x32xf32, #blocked2>
+  }
+}