@@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16
297297 tt.return %0 : tensor <64 x16 x4 xf32 , #blocked1 >
298298 }
299299}
300+
// -----

// Test transposition with 32 elements per work-item.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: llvm.func spir_kernelcc @test(
  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
  tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
    %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
    tt.return %0 : tensor<32x16xf32, #blocked1>
  }
}
343+
// -----

// Test transposition with 32 elements per work-item with a different layout.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: llvm.func spir_kernelcc @test(
  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
  tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
    tt.return %0 : tensor<16x32xf32, #blocked1>
  }
}
386+
// -----

// Test transposition with 32 elements per work-item and two warps in each dimension.

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: llvm.func spir_kernelcc @test(
  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>)
  tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> {
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32
    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64
    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]]
    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
    // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32>
    // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32
    %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
    tt.return %0 : tensor<32x64xf32, #blocked1>
  }
}
0 commit comments