diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir
new file mode 100644
index 0000000000..def61f6e73
--- /dev/null
+++ b/test/Conversion/intel/sub-group-transpose.mlir
@@ -0,0 +1,299 @@
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
+
+// Basic 16x16 transpose test
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f16 to i16
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16>
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to f16
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
+    tt.return %0 : tensor<16x16xf16, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_bf16(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_bf16(%arg0: tensor<16x16xbf16, #blocked>) -> tensor<16x16xbf16, #blocked1> {
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : bf16 to i16
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16>
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xbf16, #blocked> -> tensor<16x16xbf16, #blocked1>
+    tt.return %0 : tensor<16x16xbf16, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32>
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
+    tt.return %0 : tensor<16x16xf32, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_i8(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_i8(%arg0: tensor<16x16xi8, #blocked>) -> tensor<16x16xi8, #blocked1> {
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+    // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi8, #blocked> -> tensor<16x16xi8, #blocked1>
+    tt.return %0 : tensor<16x16xi8, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_i64(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_i64(%arg0: tensor<16x16xi64, #blocked>) -> tensor<16x16xi64, #blocked1> {
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi64, #blocked> -> tensor<16x16xi64, #blocked1>
+    tt.return %0 : tensor<16x16xi64, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_ptr(%arg0: tensor<16x16x!tt.ptr, #blocked>) -> tensor<16x16x!tt.ptr, #blocked1> {
+    // CHECK-COUNT-16: llvm.ptrtoint %{{.*}} : !llvm.ptr<1> to i64
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
+    // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64>
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64>
+    // CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked1>
+    tt.return %0 : tensor<16x16x!tt.ptr, #blocked1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_i1(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_i1(%arg0: tensor<16x16xi1, #blocked>) -> tensor<16x16xi1, #blocked1> {
+    // CHECK-COUNT-16: llvm.zext %{{.*}} : i1 to i8
+    // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
+    // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64
+    // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+    // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]]
+    // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> ()
+    // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64
+    // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8
+    // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8>
+    // CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi1, #blocked> -> tensor<16x16xi1, #blocked1>
+    tt.return %0 : tensor<16x16xi1, #blocked1>
+  }
+}
+
+// -----
+
+// Test with two sub-groups in the first dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1>
+    tt.return %0 : tensor<32x16xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with two sub-groups in the second dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1>
+    tt.return %0 : tensor<16x32xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with four sub-groups in each dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 4], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 4], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1>
+    tt.return %0 : tensor<64x64xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with four sub-groups in each dimension and an additional dimension.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1], threadsPerWarp = [1, 16, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<64x64x1xf32, #blocked>) -> tensor<64x64x1xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64x1xf32, #blocked> -> tensor<64x64x1xf32, #blocked1>
+    tt.return %0 : tensor<64x64x1xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with four sub-groups in each dimension and sliced layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1], threadsPerWarp = [1, 16, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 4], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<64x64xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<64x64xf32, #blocked1>
+    tt.return %0 : tensor<64x64xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with one sub-group and double-sliced layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<16x16x1xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<16x16x1xf32, #blocked1>
+    tt.return %0 : tensor<16x16x1xf32, #blocked1>
+  }
+}
+
+// -----
+
+// Test with four sub-groups in each dimension and double-sliced layout.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 4, 1], order = [1, 2, 3, 4, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 4], order = [1, 2, 0]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test(
+  // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test(%arg0: tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<64x16x4xf32, #blocked1> {
+    // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<64x16x4xf32, #blocked1>
+    tt.return %0 : tensor<64x16x4xf32, #blocked1>
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 44f33b13e2..aafa49c49f 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -2,6 +2,8 @@
 #include "TargetInfo.h"
 #include "Utility.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 #include "intel/include/Analysis/Utility.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
@@ -432,11 +434,64 @@ struct ConvertLayoutOpConversion
 
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
+  constexpr static unsigned minSubGroupTransposeWidth = 8;
+
+  const TargetInfoBase &targetInfo;
+
   // Set benefit to 2 so that this pattern applies before other convert-layout
   // conversions. TODO(jlebar): Eventually we want this to be the only pattern.
-  explicit ConvertLayoutOpUsingLinearLayoutsConversion(
-      LLVMTypeConverter &typeConverter, PatternBenefit benefit = 2)
-      : ConvertOpToLLVMPattern(typeConverter, benefit) {}
+  ConvertLayoutOpUsingLinearLayoutsConversion(LLVMTypeConverter &typeConverter,
+                                              const TargetInfoBase &targetInfo,
+                                              PatternBenefit benefit = 2)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  bool isSubGroupTranspose(const LinearLayout &srcLayout,
+                           const LinearLayout &dstLayout) const {
+    MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+
+    LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
+    std::optional<LinearLayout> conversion = comp.divideRight(
+        LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
+        LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
+    assert(conversion && "Expecting valid conversion");
+    // Expected conversion is:
+    // - register=1 -> (0, 1)
+    //   ...
+    // - register=i -> (0, 2**(i-1))
+    //   ...
+    // - register=N -> (0, 2**(N-1))
+    // - lane=1 -> (1, 0)
+    //   ...
+    // - lane=j -> (2**(j-1), 0)
+    //   ...
+    // - lane=M -> (2**(M-1), 0)
+    // where out dims are: [register (size 2**N), lane (size 2**M)]
+    //
+    // With N = M.
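+    //
+    // For example, with a 16-wide sub-group (N = M = 4) the expected bases
+    // are:
+    //   register: [0, 1], [0, 2], [0, 4], [0, 8]
+    //   lane:     [1, 0], [2, 0], [4, 0], [8, 0]
+    // which is exactly what buildBasis below constructs for each dimension.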
+    const auto buildBasis = [&](int32_t size, std::size_t index) {
+      std::vector<std::vector<int32_t>> basis;
+      std::vector<int32_t> curr(2);
+      for (int32_t i = 1; i < size; i *= 2) {
+        curr[index] = i;
+        basis.push_back(curr);
+      }
+      return basis;
+    };
+
+    constexpr std::size_t laneIndex = 0;
+    constexpr std::size_t registerIndex = 1;
+    int32_t size = conversion->getInDimSize(kLane);
+    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
+        bases{{{kRegister, buildBasis(size, registerIndex)},
+               {kLane, buildBasis(size, laneIndex)}}};
+    std::array<StringAttr, 2> outDimNames{kRegister, kLane};
+    return conversion == LinearLayout(bases, outDimNames);
+  }
 
   LogicalResult matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
@@ -532,15 +587,212 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }
 
+  bool isSupportedSubGroupTranspose(ConvertLayoutOp op,
+                                    OpAdaptor adaptor) const {
+    auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
+    ArrayRef<Type> body = srcType.getBody();
+    // TODO: Support more configurations.
+    auto mod = op->getParentOfType<ModuleOp>();
+    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+    if (body.size() != threadsPerWarp)
+      return false;
+    return TypeSwitch<Type, bool>(body.front())
+        .Case([this](FloatType floatTy) {
+          // Support via bitcasting to integer type.
+          return isValidTypeForSubGroupTranspose(
+              IntegerType::get(floatTy.getContext(), floatTy.getWidth()));
+        })
+        .Case([this](IntegerType intTy) {
+          // Support via extending to supported type.
+          return isValidTypeForSubGroupTranspose(intTy) ||
+                 intTy.getWidth() < minSubGroupTransposeWidth;
+        })
+        .Case([](LLVM::LLVMPointerType) {
+          // Support via ptrtoint.
+          return true;
+        })
+        .Default(false);
+  }
+
   LogicalResult transferWithinLane(ConvertLayoutOp op,
                                    const LinearLayout &srcLayout,
                                    const LinearLayout &dstLayout,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
+    // If the operation is a supported sub-group transposition, perform via SLM.
+    if (isSubGroupTranspose(srcLayout, dstLayout) &&
+        isSupportedSubGroupTranspose(op, adaptor)) {
+      performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter);
+      return success();
+    }
     // TODO(jlebar): Implement me.
     return failure();
   }
 
+  bool isValidTypeForSubGroupTranspose(Type type) const {
+    return TypeSwitch<Type, bool>(type)
+        .Case([](IntegerType intTy) {
+          unsigned width = intTy.getWidth();
+          return width == 8 || width == 16 || width == 32 || width == 64;
+        })
+        .Default(false);
+  }
+
+  void performSubGroupTranspose(ConvertLayoutOp op,
+                                const LinearLayout &srcLayout,
+                                const LinearLayout &dstLayout,
+                                OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    assert(isSubGroupTranspose(srcLayout, dstLayout) &&
+           "Expecting sub-group transpose");
+    assert(isSupportedSubGroupTranspose(op, adaptor) &&
+           "Expecting supported sub-group transpose");
+
+    Location loc = op.getLoc();
+
+    SmallVector<Value> inVals =
+        unpackLLElements(loc, adaptor.getSrc(), rewriter);
+
+    auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
+    Type origElemTy = inVals.front().getType();
+
+    TypeSwitch<Type>(origElemTy)
+        .Case([&](FloatType floatTy) {
+          // TODO: Support FP4.
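+          // Bitcast each element to the integer type of the same width
+          // (e.g. f16/bf16 -> i16, f32 -> i32) so the SLM round trip below
+          // only has to handle integer vectors.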
+          Type dstType = int_ty(floatTy.getWidth());
+          assert(isValidTypeForSubGroupTranspose(dstType) &&
+                 "Expecting valid type");
+          llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
+            return bitcast(val, dstType);
+          });
+        })
+        .Case([&](IntegerType intTy) {
+          if (isValidTypeForSubGroupTranspose(intTy))
+            return;
+          assert(intTy.getWidth() < minSubGroupTransposeWidth &&
+                 "Expecting type to extend to i8");
+          Type dstType = i8_ty;
+          llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
+            return zext(dstType, val);
+          });
+        })
+        .Case([&](LLVM::LLVMPointerType) {
+          Type dstType = i64_ty;
+          assert(isValidTypeForSubGroupTranspose(dstType) &&
+                 "i64 type should be supported");
+          llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
+            return ptrtoint(dstType, val);
+          });
+        })
+        .Default([](auto) { llvm_unreachable("Unsupported type"); });
+
+    SmallVector<Value> outVals =
+        performSubGroupTranspose(loc, inVals, rewriter);
+
+    TypeSwitch<Type>(origElemTy)
+        .Case([&](FloatType floatTy) {
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return bitcast(val, origElemTy); });
+        })
+        .Case([&](IntegerType intTy) {
+          // Check whether conversion took place.
+          if (intTy == outVals.front().getType())
+            return;
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return trunc(origElemTy, val); });
+        })
+        .Case([&](LLVM::LLVMPointerType ptrTy) {
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return inttoptr(ptrTy, val); });
+        })
+        .Default([](auto) { llvm_unreachable("Unsupported type"); });
+
+    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
+                                  op.getType());
+    rewriter.replaceOp(op, result);
+  }
+
+  VectorType
+  getTypeForSubGroupTranspose(ArrayRef<Value> inVals,
+                              ConversionPatternRewriter &rewriter) const {
+    auto elementTy = cast<IntegerType>(inVals.front().getType());
+    return elementTy.getWidth() <= 16 ? vec_ty(elementTy, 16)
+                                      : vec_ty(elementTy, 8);
+  }
+
+  Value wrapInVector(Location loc, VectorType type, ArrayRef<Value> values,
+                     ConversionPatternRewriter &rewriter) const {
+    assert(type.getShape()[0] == values.size() && "Size mismatch");
+    Value res = rewriter.create<LLVM::UndefOp>(loc, type);
+    for (auto [index, val] : llvm::enumerate(values))
+      res = insert_element(res, val, i32_val(index));
+    return res;
+  }
+
+  SmallVector<Value>
+  unwrapFromVector(Location loc, Value vec,
+                   ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> res;
+    for (unsigned i = 0, n = cast<VectorType>(vec.getType()).getShape()[0];
+         i < n; ++i)
+      res.push_back(extract_element(vec, i32_val(i)));
+    return res;
+  }
+
+  SmallVector<Value>
+  performSubGroupTranspose(Location loc, ArrayRef<Value> inVals,
+                           ConversionPatternRewriter &rewriter) const {
+    VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter);
+    auto mod = rewriter.getInsertionPoint()->getParentOfType<ModuleOp>();
+    unsigned vecWidth = opType.getShape()[0];
+
+    Value smemBase = LLVM::intel::getSharedMemoryBase(
+        loc, rewriter, targetInfo, &*rewriter.getInsertionPoint());
+    Type ptrType = smemBase.getType();
+
+    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+    int offset = threadsPerWarp;
+    Type offsetType = getTypeConverter()->getIndexType();
+    Value subGroupId = getValueOrCreateCastToIndexLike(
+        rewriter, loc, offsetType,
+        rewriter.create<mlir::gpu::SubgroupIdOp>(
+            loc, /*upper_bound=*/IntegerAttr{}));
+    Value subGroupLocalId = getValueOrCreateCastToIndexLike(
+        rewriter, loc, offsetType,
+        rewriter.create<mlir::gpu::LaneIdOp>(loc,
+                                             /*upper_bound=*/IntegerAttr{}));
+    Value wiStride =
+        rewriter.create<LLVM::ConstantOp>(loc, offsetType, threadsPerWarp);
+    Value sgStride = rewriter.create<LLVM::ConstantOp>(
+        loc, offsetType, threadsPerWarp * threadsPerWarp);
+    Value subGroupOffset = mul(sgStride, subGroupId);
+    Type elementType = opType.getElementType();
+    Value subGroupBasePtr = gep(ptrType, elementType, smemBase,
+                                ValueRange{subGroupOffset}, /*inbounds=*/true);
+    Value base = subGroupBasePtr;
+    // Store in matrix, transposed.
+    for (ArrayRef<Value> vals = inVals; !vals.empty();
+         vals = vals.drop_front(vecWidth)) {
+      ArrayRef<Value> curr = vals.take_front(vecWidth);
+      Value vec = wrapInVector(loc, opType, curr, rewriter);
+      rewriter.create(loc, base, vec);
+      base = gep(base.getType(), opType, base, ArrayRef<LLVM::GEPArg>{offset},
+                 /*inbounds=*/true);
+    }
+
+    // Load from matrix, non-transposed.
+    Value workItemOffset = mul(wiStride, subGroupLocalId);
+    Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr,
+                                ValueRange{workItemOffset}, /*inbounds=*/true);
+    Value transposedVec =
+        load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr);
+
+    return unwrapFromVector(loc, transposedVec, rewriter);
+  }
+
   LogicalResult transferWithinBlockGroup(ConvertLayoutOp op,
                                          const LinearLayout &srcLayout,
                                          const LinearLayout &dstLayout,
                                          OpAdaptor adaptor,
@@ -560,7 +812,7 @@ void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns(
   // Eventually the LL conversion will subsume all of the others and be the only
   // one left.
   patterns.add<ConvertLayoutOpUsingLinearLayoutsConversion>(
-      typeConverter, benefit.getBenefit() + 1);
+      typeConverter, targetInfo, benefit.getBenefit() + 1);
   patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
 }