diff --git a/test/Conversion/intel/intel-allocate-shared-memory.mlir b/test/Conversion/intel/intel-allocate-shared-memory.mlir index 81cbcf4a31..703c90ac31 100644 --- a/test/Conversion/intel/intel-allocate-shared-memory.mlir +++ b/test/Conversion/intel/intel-allocate-shared-memory.mlir @@ -81,3 +81,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.return %0 : tensor<128x64xf32, #blocked1> } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [2, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> + +// Check scratch memory configuration for different sub-group transpose-like layout conversions. + +// CHECK-LABEL: module attributes +// CHECK-SAME: triton_gpu.shared = 17408 : i32 +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + tt.func @test_contiguous(%arg0: tensor<32x128xf32, #blocked>) -> tensor<32x128xf32, #blocked1> { + %0 = triton_gpu.convert_layout %arg0 : tensor<32x128xf32, #blocked> -> tensor<32x128xf32, #blocked1> + tt.return %0 : tensor<32x128xf32, #blocked1> + } +} diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index b4a9b242a7..bf884d8d3a 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -25,7 +25,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul 
%[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> @@ -53,7 +53,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_usPU3AS3tt( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi16> @@ -81,7 +81,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -108,7 +108,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: 
%[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> @@ -134,7 +134,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> @@ -161,7 +161,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi64> @@ -189,7 +189,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ucPU3AS3hh( // CHECK-NOT: llvm.call spir_funccc 
@_Z30intel_sub_group_block_write_ucPU3AS3hh( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi8> @@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -263,7 +263,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -300,7 +300,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call 
spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -337,7 +337,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -410,7 +410,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, 
"triton_gpu.num-warps" = 1 : // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -447,7 +447,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // COM: Check there are 15 more stores: // CHECK-COUNT-15: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -485,7 +485,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> 
-> vector<16xi32> @@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -561,7 +561,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(272 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(17 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -596,9 +596,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: llvm.func spir_kernelcc @test( - // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) - tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { + // CHECK-LABEL: llvm.func spir_kernelcc 
@test_32( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test_32(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 @@ -616,7 +616,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // COM: Check there are 31 more stores: // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( - // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(1056 : i64) : i64 + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(33 : i64) : i64 // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], %[[VAL_59]] : i64 // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> @@ -629,3 +629,61 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %0 : tensor<64x64xf32, #blocked1> } } + +// ----- + +// Test transposition with two contiguous rows. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_2_cont( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test_2_cont(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAL_35:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[VAL_34]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_36:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_37:.*]] = llvm.zext %[[VAL_36]] : i32 to i64 + // CHECK: %[[VAL_38:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {no_unwind, will_return} : () -> i32 + // CHECK: %[[VAL_39:.*]] = llvm.zext %[[VAL_38]] : i32 to i64 + // COM: Offset changes with increased number of columns: + // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(544 : i64) : i64 + // CHECK: %[[VAL_41:.*]] = llvm.mul %[[VAL_37]], %[[VAL_40]] : i64 + // CHECK: %[[VAL_42:.*]] = llvm.getelementptr inbounds %[[VAL_35]]{{\[}}%[[VAL_41]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj(%[[VAL_42]] + // COM: Check offset: + // CHECK: llvm.getelementptr inbounds %{{.*}}[17] + // COM: Check there are 31 more stores: + // CHECK-COUNT-31: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK-NOT: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS3jj( + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(34 : i64) : i64 + // CHECK: %[[VAL_60:.*]] = llvm.mul %[[VAL_39]], 
%[[VAL_59]] : i64 + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_42]]{{\[}}%[[VAL_60]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_61]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32> + // CHECK: %[[VAL_64:.*]] = llvm.getelementptr inbounds %[[VAL_62]][1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_64]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + tt.return %0 : tensor<64x64xf32, #blocked1> + } +} + +// ----- + +// Test no barrier is introduced between transpositions with two contiguous rows. + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 2], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_2_cont_back_2_back( + tt.func @test_2_cont_back_2_back(%arg0: tensor<64x64xf32, #blocked>, %arg1: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked1>, tensor<64x64xf32, #blocked1>) { + // CHECK-NOT: barrier + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + %1 = triton_gpu.convert_layout %arg1 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + tt.return %0, %1 : tensor<64x64xf32, #blocked1>, tensor<64x64xf32, #blocked1> + } +} diff --git a/third_party/intel/lib/Analysis/Utility.cpp b/third_party/intel/lib/Analysis/Utility.cpp index cba7f77398..2f735d1b11 100644 --- a/third_party/intel/lib/Analysis/Utility.cpp +++ b/third_party/intel/lib/Analysis/Utility.cpp @@ -50,6 
+50,46 @@ buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) { return bases; } +// Return a vector such as: +// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ..., +// [registerSize / (laneSize * 2), 0]], +// i.e., mapping registers to lanes till laneSize and performing an ID +// conversion afterwards. +std::vector> +buildContiguousSubGroupTransposeRegisterBases(int32_t registerSize, + int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + int32_t i = 1; + for (; i < laneSize; i *= 2) { + curr[1] = i; + bases.push_back(curr); + } + curr[1] = 0; + for (int32_t j = 1; i < registerSize; i *= 2, j *= 2) { + curr[0] = j; + bases.push_back(curr); + } + return bases; +} + +// Return a vector such as: +// [[registerSize / laneSize, 0], [registerSize / laneSize * 2, 0], ..., +// [registerSize / 2, 0]] +// i.e., mapping registers to lanes till laneSize and performing an ID +// conversion afterwards. +std::vector> +buildContiguousSubGroupTransposeLaneBases(int32_t registerSize, + int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + for (int32_t i = registerSize / laneSize; i < registerSize; i *= 2) { + curr[0] = i; + bases.push_back(curr); + } + return bases; +} + // Return a vector such as: // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ..., // [registerSize / (2 * laneSize), 0]] @@ -253,7 +293,7 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) { // - register=2**k -> (2**k, 0) // ... // - register=N -> (2**N, 0) - // - lane=1 -> (0, 1) + // - lane=1 -> (1, 0) // ... // - lane=2**j -> (2**j, 0) // ... @@ -261,13 +301,42 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) { // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] // // With N >= M. 
+ // + // Alternatively, we can also lower transpositions in which the output matrix + // has more than one contiguous row owned by the same thread, resulting in: + // + // - register=1 -> (0, 1) + // ... + // - register=2**i -> (0, 2**i) + // ... + // - register=M -> (0, 2**M) + // ... + // - register=2**k -> (1, 0) + // ... + // - register=N -> (2**(N-k), 0) + // - lane=1 -> (2**(N-k+1), 0) + // ... + // - lane=2**j -> (2**(N-k+j), 0) + // ... + // lane=2**M -> (2**(N-k+M), 0) + // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] + // + // With N >= M. + // + // This is what we call the "contiguous" case. int32_t registerInDimSize = conversion->getInDimSize(kRegister); int32_t laneInDimSize = conversion->getInDimSize(kLane); - return conversion->getBases().lookup(kRegister) == - buildSubGroupTransposeRegisterBases(registerInDimSize, - laneInDimSize) && - conversion->getBases().lookup(kLane) == - buildSubGroupTransposeLaneBases(laneInDimSize); + return (conversion->getBases().lookup(kRegister) == + buildSubGroupTransposeRegisterBases(registerInDimSize, + laneInDimSize) && + conversion->getBases().lookup(kLane) == + buildSubGroupTransposeLaneBases(laneInDimSize)) || + (conversion->getBases().lookup(kRegister) == + buildContiguousSubGroupTransposeRegisterBases(registerInDimSize, + laneInDimSize) && + conversion->getBases().lookup(kLane) == + buildContiguousSubGroupTransposeLaneBases(registerInDimSize, + laneInDimSize)); } } // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index c3a5b8da74..f8c4dc7fc3 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -694,6 +694,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return res; } + int getNumContiguousRowsForTranspose(const LinearLayout 
&srcLayout, + const LinearLayout &dstLayout) const { + MLIRContext *ctx = getContext(); + + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); + LinearLayout comp = + *dstLayout.invertAndCompose(srcLayout).quotient({kWarp, kBlock}); + // Basic case: the number of contiguous rows is 1. + if (comp.getBasis(kLane, 0)[0] == 1) + return 1; + // Otherwise, we only allow all the rows handled by a single thread to + // be contiguous, so we can simply compute: + int32_t sizePerThread = comp.getOutDimSize(kRegister); + int32_t threadsPerWarp = comp.getOutDimSize(kLane); + assert(sizePerThread % threadsPerWarp == 0 && "Invalid transpose"); + return sizePerThread / threadsPerWarp; + } + void performSubGroupTranspose(ConvertLayoutOp op, const LinearLayout &srcLayout, const LinearLayout &dstLayout, @@ -738,8 +759,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion }) .Default([](auto) { llvm_unreachable("Unsupported type"); }); - SmallVector outVals = - performSubGroupTranspose(loc, inVals, rewriter); + SmallVector outVals = performSubGroupTranspose( + loc, inVals, rewriter, + getNumContiguousRowsForTranspose(srcLayout, dstLayout)); TypeSwitch(origElemTy) .Case([&](FloatType floatTy) { @@ -788,7 +810,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion SmallVector performSubGroupTranspose(Location loc, ArrayRef inVals, - ConversionPatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter, + int numContiguousRows) const { Type elementType = inVals.front().getType(); auto mod = rewriter.getInsertionPoint()->getParentOfType(); @@ -828,12 +851,18 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Each work-item will load a row (but the last garbage element) and go to // the next row it needs to handle. - int32_t workItemStride = rowLength * threadsPerWarp; + + int32_t workItemStride = + numContiguousRows == 1 ? 
rowLength * threadsPerWarp : rowLength; Value workItemOffset = - mul(subGroupLocalId, int_val(offsetBitWidth, workItemStride)); + mul(subGroupLocalId, + int_val(offsetBitWidth, numContiguousRows * rowLength)); Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr, ValueRange{workItemOffset}, /*inbounds=*/true); int32_t rowsPerThread = numRows / threadsPerWarp; + assert((numContiguousRows == 1 || numContiguousRows == rowsPerThread) && + "In case of more than one contiguous rows per thread, these must be " + "consecutive"); // We may not be able to load rows in a single operation if the sub-group // size exceeds a given threshold (16): unsigned vecLoadWidth = getVecLoadWidth(threadsPerWarp);