From 642062d20978e8915353e5bacedddcb8c20d856f Mon Sep 17 00:00:00 2001 From: victor-eds Date: Thu, 21 Nov 2024 15:57:56 +0000 Subject: [PATCH 1/2] [XPU][TritonIntelGPUToLLVM] Add support for more transpose kinds Add support for layout conversion transposes in which rows managed by a single thread are contiguous in the output matrix. Signed-off-by: victor-eds --- .../intel/intel-allocate-shared-memory.mlir | 16 +++ .../Conversion/intel/sub-group-transpose.mlir | 100 ++++++++++++++---- third_party/intel/lib/Analysis/Utility.cpp | 81 ++++++++++++-- .../ConvertLayoutOpToLLVM.cpp | 39 ++++++- 4 files changed, 204 insertions(+), 32 deletions(-) diff --git a/test/Conversion/intel/intel-allocate-shared-memory.mlir b/test/Conversion/intel/intel-allocate-shared-memory.mlir index 5fad77531e..ec379fa269 100644 --- a/test/Conversion/intel/intel-allocate-shared-memory.mlir +++ b/test/Conversion/intel/intel-allocate-shared-memory.mlir @@ -63,3 +63,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.return %0 : tensor<128x64xf32, #blocked1> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> + +// Check scracth memory configuration for different sub-group transpose-like layout conversions. + +// CHECK-LABEL: module attributes +// CHECK-SAME: triton_gpu.shared = 17408 : i32 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func @test_contiguous(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf32, #blocked1> { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #blocked1> + tt.return %0 : tensor<32x128xf32, #blocked1> + } +} diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index b4a9b242a7..bf884d8d3a 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -25,7 +25,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> @@ -53,7 +53,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> @@ -81,7 +81,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -108,7 +108,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> @@ -134,7 +134,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> @@ -161,7 +161,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> @@ -189,7 +189,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> @@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -263,7 +263,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -300,7 +300,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -337,7 +337,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -410,7 +410,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -447,7 +447,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -485,7 +485,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -561,7 +561,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -596,9 +596,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: llvm.func spir_kernelcc @test( - // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) - tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { + // CHECK-LABEL: llvm.func spir_kernelcc @test_32( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test_32(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 @@ -616,7 +616,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(1056 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(33 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -629,3 +629,61 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %0 : tensor<64x64xf32, #blocked1> } } + +// ----- + +// Test transposition with two contiguous rows. + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_2_cont( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test_2_cont(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // COM: Offset changes with increased number of columns: + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 31 more stores: + // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(34 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32> + // CHECK: %[[VAL_64:.*]] = llvm.getelementptr inbounds %[[VAL_62]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_64]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + tt.return %0 : tensor<64x64xf32, #blocked1> + } +} + +// ----- + +// Test no barrier is introduced between transpositions with two contiguous rows. + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_2_cont_back_2_back( + tt.func @test_2_cont_back_2_back(%arg0: tensor<64x64xf32, #blocked>, %arg1: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked1>, tensor<64x64xf32, #blocked1>) { + // CHECK-NOT: barrier + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + %1 = triton_gpu.convert_layout %arg1 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + tt.return %0, %1 : tensor<64x64xf32, #blocked1>, tensor<64x64xf32, #blocked1> + } +} diff --git a/third_party/intel/lib/Analysis/Utility.cpp b/third_party/intel/lib/Analysis/Utility.cpp index 13b8607192..8aa60dafdb 100644 --- a/third_party/intel/lib/Analysis/Utility.cpp +++ b/third_party/intel/lib/Analysis/Utility.cpp @@ -50,6 +50,46 @@ buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) { return bases; } +// Return a vector such as: +// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ..., +// [registerSize / (laneSize * 2), 0]], +// i.e., mapping registers to lanes till laneSize and performing an ID +// conversion afterwards. +std::vector> +buildContiguousSubGroupTransposeRegisterBases(int32_t registerSize, + int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + int i = 1; + for (; i < laneSize; i *= 2) { + curr[1] = i; + bases.push_back(curr); + } + curr[1] = 0; + for (int32_t j = 1; i < registerSize; i *= 2, j *= 2) { + curr[0] = j; + bases.push_back(curr); + } + return bases; +} + +// Return a vector such as: +// [[registerSize / laneSize, 0], [registerSize / laneSize * 2, 0], ..., +// [registerSize / 2, 0]] +// i.e., mapping registers to lanes till laneSize and performing an ID +// conversion afterwards. +std::vector> +buildContiguousSubGroupTransposeLaneBases(int32_t registerSize, + int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + for (int32_t i = registerSize / laneSize; i < registerSize; i *= 2) { + curr[0] = i; + bases.push_back(curr); + } + return bases; +} + // Return a vector such as: // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ..., // [registerSize / (2 * laneSize), 0]] @@ -206,7 +246,7 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) { // - register=2**k -> (2**k, 0) // ... // - register=N -> (2**N, 0) - // - lane=1 -> (0, 1) + // - lane=1 -> (1, 0) // ... // - lane=2**j -> (2**j, 0) // ... @@ -214,13 +254,42 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) { // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] // // With N >= M. + // + // Alternatively, we can also lower transpositions in which the output matrix + // has more than one contiguous row owned by the same thread, resulting in: + // + // - register=1 -> (0, 1) + // ... + // - register=2**i -> (0, 2**i) + // ... + // - register=M -> (0, 2**M) + // ... + // - register=2**k -> (1, 0) + // ... + // - register=N -> (2**(N-k), 0) + // - lane=1 -> (2**(N-k+1), 0) + // ... + // - lane=2**j -> (2**(N-k+j), 0) + // ... + // lane=2**M -> (2**(N-k+M), 0) + // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] + // + // With N >= M. + // + // This is what we call the "contiguous" case. int32_t registerInDimSize = conversion->getInDimSize(kRegister); int32_t laneInDimSize = conversion->getInDimSize(kLane); - return conversion->getBases().lookup(kRegister) == - buildSubGroupTransposeRegisterBases(registerInDimSize, - laneInDimSize) && - conversion->getBases().lookup(kLane) == - buildSubGroupTransposeLaneBases(laneInDimSize); + return (conversion->getBases().lookup(kRegister) == + buildSubGroupTransposeRegisterBases(registerInDimSize, + laneInDimSize) && + conversion->getBases().lookup(kLane) == + buildSubGroupTransposeLaneBases(laneInDimSize)) || + (conversion->getBases().lookup(kRegister) == + buildContiguousSubGroupTransposeRegisterBases(registerInDimSize, + laneInDimSize) && + conversion->getBases().lookup(kLane) == + buildContiguousSubGroupTransposeLaneBases(registerInDimSize, + laneInDimSize)); } } // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 72d5f7e291..c1eeada473 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -653,6 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return res; } + int getNumContiguousRowsForTranspose(const LinearLayout &srcLayout, + const LinearLayout &dstLayout) const { + MLIRContext *ctx = getContext(); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + LinearLayout comp = + *dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock}); + // Basic case: the number of contiguous rows is 0. + if (comp.getBasis(kLane, 0)[0] == 1) + return 1; + // In other case, we only allow all threads handled by a single element to + // be contiguous, so we can simply: + int32_t sizePerThread = comp.getOutDimSize(kRegister); + int32_t threadsPerWarp = comp.getOutDimSize(kLane); + assert(sizePerThread % threadsPerWarp == 0 && "Invalid transpose"); + return sizePerThread / threadsPerWarp; + } + void performSubGroupTranspose(ConvertLayoutOp op, const LinearLayout &srcLayout, const LinearLayout &dstLayout, @@ -697,8 +718,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion }) .Default([](auto) { llvm_unreachable("Unsupported type"); }); - SmallVector outVals = - performSubGroupTranspose(loc, inVals, rewriter); + SmallVector outVals = performSubGroupTranspose( + loc, inVals, rewriter, + getNumContiguousRowsForTranspose(srcLayout, dstLayout)); TypeSwitch(origElemTy) .Case([&](FloatType floatTy) { @@ -747,7 +769,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion SmallVector performSubGroupTranspose(Location loc, ArrayRef inVals, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + int numContiguousRows) const { Type elementType = inVals.front().getType(); auto mod = rewriter.getInsertionPoint()->getParentOfType(); @@ -787,12 +810,18 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Each work-item will load a row (but the last garbage element) and go to // the next row it needs to handle. - int32_t workItemStride = rowLength * threadsPerWarp; + + int32_t workItemStride = + numContiguousRows == 1 ? rowLength * threadsPerWarp : rowLength; Value workItemOffset = - mul(subGroupLocalId, int_val(offsetBitWidth, workItemStride)); + mul(subGroupLocalId, + int_val(offsetBitWidth, numContiguousRows * rowLength)); Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr, ValueRange{workItemOffset}, /*inbounds=*/true); int32_t rowsPerThread = numRows / threadsPerWarp; + assert((numContiguousRows == 1 || numContiguousRows == rowsPerThread) && + "In case of more than one contiguous rows per thread, these must be " + "consecutive"); // We may not be able to load rows in a single operation if the sub-group // size exceeds a given threshold (16): unsigned vecLoadWidth = getVecLoadWidth(threadsPerWarp); From 92f40bf0d0c19bf42751b7f4d0aded303e8b48d2 Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Tue, 26 Nov 2024 15:07:41 +0100 Subject: [PATCH 2/2] Address comments --- third_party/intel/lib/Analysis/Utility.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/lib/Analysis/Utility.cpp b/third_party/intel/lib/Analysis/Utility.cpp index 5ad6fbd269..2f735d1b11 100644 --- a/third_party/intel/lib/Analysis/Utility.cpp +++ b/third_party/intel/lib/Analysis/Utility.cpp @@ -60,7 +60,7 @@ buildContiguousSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) { std::vector> bases; std::vector curr(2); - int i = 1; + int32_t i = 1; for (; i < laneSize; i *= 2) { curr[1] = i; bases.push_back(curr);