From d3869512a42bdbb8ac810c9bdef44c9adbb04273 Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Fri, 18 Oct 2024 13:18:42 +0100
Subject: [PATCH 1/7] [TritonIntelGPUToLLVM] Detect sub-group transpose cases
 when using linear layout

Detect sub-group transpose cases as those in which the warp and lane
dimensions are swapped and no transfer within block-groups is needed. Use
sub-group write operations to store the contents in local memory and vector
operations to read them back. These translate to non-transposed stores and
transposed loads, respectively. As data is moved within sub-groups, no
barriers are needed.

For now, handle only the case of a single `sub_group_size^2` block being
transposed. This may be extended in the future by performing `N*M` iterations
for matrices of size `N*sub_group_size x M*sub_group_size`.

Signed-off-by: victor-eds
---
 .../Conversion/intel/sub-group-transpose.mlir | 404 ++++++++++++++++++
 .../ConvertLayoutOpToLLVM.cpp                 | 214 +++++++++-
 2 files changed, 613 insertions(+), 5 deletions(-)
 create mode 100644 test/Conversion/intel/sub-group-transpose.mlir

diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir
new file mode 100644
index 0000000000..a666bee600
--- /dev/null
+++ b/test/Conversion/intel/sub-group-transpose.mlir
@@ -0,0 +1,404 @@
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --canonicalize | FileCheck %s
+
+// Basic 16x16 transpose test
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL:   llvm.func spir_kernelcc @test_f16(
+  // CHECK-SAME:    , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
+    // CHECK-DAG:   %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64
+    // CHECK-DAG:   %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f16 to i16
+    // CHECK:       %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+    // CHECK:       %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64
+    // CHECK:       %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+    // CHECK:       %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64
+    // CHECK:       %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64
+    // CHECK:       %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK:       llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]]
+    // CHECK-SAME:  (!llvm.ptr<3>, vector<16xi16>) -> ()
+    // CHECK:       %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64
+    // CHECK:       %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16
+    // CHECK:       llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16>
+    // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to f16
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
+    tt.return %0 : tensor<16x16xf16, #blocked1>
+  }
+
+  // CHECK-LABEL:   llvm.func spir_kernelcc @test_bf16(
+  // CHECK-SAME:    , %[[VAL_1:.*]]: !llvm.ptr<3>
+  tt.func @test_bf16(%arg0: tensor<16x16xbf16, 
#blocked>) -> tensor<16x16xbf16, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : bf16 to i16 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xbf16, #blocked> -> tensor<16x16xbf16, #blocked1> + tt.return %0 : tensor<16x16xbf16, #blocked1> + } + + // CHECK-LABEL: llvm.func spir_kernelcc @test_f32( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1> + tt.return %0 : tensor<16x16xf32, #blocked1> + } + + // CHECK-LABEL: llvm.func spir_kernelcc @test_i8( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test_i8(%arg0: tensor<16x16xi8, #blocked>) -> tensor<16x16xi8, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : 
i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi8, #blocked> -> tensor<16x16xi8, #blocked1> + tt.return %0 : tensor<16x16xi8, #blocked1> + } + + // CHECK-LABEL: llvm.func spir_kernelcc @test_i64( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test_i64(%arg0: tensor<16x16xi64, #blocked>) -> tensor<16x16xi64, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi64, #blocked> -> tensor<16x16xi64, #blocked1> + tt.return %0 : tensor<16x16xi64, #blocked1> + } + + // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test_ptr(%arg0: tensor<16x16x!tt.ptr, #blocked>) -> tensor<16x16x!tt.ptr, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.ptrtoint %{{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]], + // CHECK-SAME: 
(!llvm.ptr<3>, vector<8xi64>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> + // CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1> + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x!tt.ptr, #blocked> -> tensor<16x16x!tt.ptr, #blocked1> + tt.return %0 : tensor<16x16x!tt.ptr, #blocked1> + } + + // CHECK-LABEL: llvm.func spir_kernelcc @test_i1( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test_i1(%arg0: tensor<16x16xi1, #blocked>) -> tensor<16x16xi1, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.zext %{{.*}} : i1 to i8 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> + // CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi1, #blocked> -> tensor<16x16xi1, #blocked1> + tt.return %0 : tensor<16x16xi1, #blocked1> + } +} + +// ----- + +// Test with two sub-groups in the first dimension. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1> + tt.return %0 : tensor<32x16xf32, #blocked1> + } +} + +// ----- + +// Test with two sub-groups in the second dimension. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1> + tt.return %0 : tensor<16x32xf32, #blocked1> + } +} + +// ----- + +// Test with four sub-groups in each dimension. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 4], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 4], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + tt.return %0 : tensor<64x64xf32, #blocked1> + } +} + +// ----- + +// Test with four sub-groups in each dimension and an additional dimension. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1], threadsPerWarp = [1, 16, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<64x64x1xf32, #blocked>) -> tensor<64x64x1xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64x1xf32, #blocked> -> tensor<64x64x1xf32, #blocked1> + tt.return %0 : tensor<64x64x1xf32, #blocked1> + } +} +// ----- + +// Test with four sub-groups in each dimension and sliced layout. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1], threadsPerWarp = [1, 16, 1], warpsPerCTA = [4, 4, 1], order = [0, 1, 2]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 4], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<64x64xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<64x64xf32, #blocked1> + tt.return %0 : tensor<64x64xf32, #blocked1> + } +} + +// ----- + +// Test with one sub-group and double-sliced layout. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1], order = [1, 2, 3, 4, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [1, 2, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<16x16x1xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<16x16x1xf32, #blocked1> + tt.return %0 : tensor<16x16x1xf32, #blocked1> + } +} + +// ----- + +// Test with four sub-groups in each dimension and double-sliced layout. 
+ +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1], threadsPerWarp = [1, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 4, 1], order = [1, 2, 3, 4, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [4, 1, 4], order = [1, 2, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> + tt.func @test(%arg0: tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<64x16x4xf32, #blocked1> { + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<64x16x4xf32, #blocked1> + tt.return %0 : tensor<64x16x4xf32, #blocked1> + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 44f33b13e2..44e85b869f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -2,6 +2,8 @@ #include "TargetInfo.h" #include "Utility.h" +#include "llvm/ADT/TypeSwitch.h" + #include "intel/include/Analysis/Utility.h" #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h" @@ -432,11 +434,34 @@ struct ConvertLayoutOpConversion struct ConvertLayoutOpUsingLinearLayoutsConversion : public ConvertOpToLLVMPattern { + constexpr static unsigned maxSubGroupTransposeWidth = 64; + // Set benefit to 2 so that this pattern applies before other convert-layout // conversions. TODO(jlebar): Eventually we want this to be the only pattern. 
- explicit ConvertLayoutOpUsingLinearLayoutsConversion( - LLVMTypeConverter &typeConverter, PatternBenefit benefit = 2) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} + ConvertLayoutOpUsingLinearLayoutsConversion( + LLVMTypeConverter &typeConverter, + const triton::intel::TargetInfo &targetInfo, PatternBenefit benefit = 2) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + bool isSubGroupTranspose(const LinearLayout &srcLayout, + const LinearLayout &dstLayout) const { + if (srcLayout.getInDimNames().empty()) + return false; + + MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext(); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + if (!srcLayout.hasInDim(kLane) || !dstLayout.hasInDim(kLane) || + !srcLayout.hasInDim(kRegister) || !dstLayout.hasInDim(kRegister)) + return false; + + auto srcBases = srcLayout.getBases(); + auto dstBases = dstLayout.getBases(); + + return srcBases[kRegister] == dstBases[kLane] && + srcBases[kLane] == dstBases[kRegister]; + } LogicalResult matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor, @@ -537,10 +562,186 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion const LinearLayout &dstLayout, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // TODO(jlebar): Implement me. + if (isSubGroupTranspose(srcLayout, dstLayout)) + return performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, + rewriter); return failure(); } + bool isValidTypeForSubGroupTranspose(Type type) const { + return TypeSwitch(type) + .Case([](IntegerType intTy) { + unsigned width = intTy.getWidth(); + return width == 8 || width == 16 || width == 32 || width == 64; + }) + .Default([](auto) { return false; }); + } + + LogicalResult + performSubGroupTranspose(ConvertLayoutOp op, const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op.getLoc(); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // TODO: Support multiples of sub_group_size + auto mod = op->getParentOfType(); + if (inVals.size() != + mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)) + return failure(); + + auto srcTy = cast(op.getSrc().getType()); + Type origElemTy = srcTy.getElementType(); + + LogicalResult conversionRes = + TypeSwitch(origElemTy) + .Case([&](FloatType floatTy) { + // TODO: Support FP4. 
+ Type dstType = int_ty(floatTy.getWidth()); + if (!isValidTypeForSubGroupTranspose(dstType)) + return failure(); + llvm::transform( + inVals, std::begin(inVals), + [&](Value val) -> Value { return bitcast(val, dstType); }); + return success(); + }) + .Case([&](IntegerType intTy) { + if (isValidTypeForSubGroupTranspose(intTy)) + return success(); + if (intTy.getWidth() > maxSubGroupTransposeWidth) + return failure(); + // intTy.getWidth() < minSubGroupTransposeWidth + Type dstType = i8_ty; + llvm::transform( + inVals, std::begin(inVals), + [&](Value val) -> Value { return zext(dstType, val); }); + return success(); + }) + .Case([&](triton::PointerType) { + Type dstType = i64_ty; + assert(isValidTypeForSubGroupTranspose(dstType) && + "i64 type should be supported"); + llvm::transform( + inVals, std::begin(inVals), + [&](Value val) -> Value { return ptrtoint(dstType, val); }); + return success(); + }) + .Default([&](auto) { return failure(); }); + + if (failed(conversionRes)) + return conversionRes; + + SmallVector outVals = + performSubGroupTranspose(loc, inVals, rewriter); + + TypeSwitch(origElemTy) + .Case([&](FloatType floatTy) { + llvm::transform( + outVals, std::begin(outVals), + [&](Value val) -> Value { return bitcast(val, origElemTy); }); + }) + .Case([&](IntegerType intTy) { + // Check whether conversion took place. + if (intTy == outVals.front().getType()) + return; + llvm::transform( + outVals, std::begin(outVals), + [&](Value val) -> Value { return trunc(origElemTy, val); }); + }) + .Case([&](triton::PointerType ptrTy) { + Type llvmPtrTy = getTypeConverter()->convertType(ptrTy); + assert(llvmPtrTy && "Type conversion failed"); + llvm::transform( + outVals, std::begin(outVals), + [&](Value val) -> Value { return inttoptr(llvmPtrTy, val); }); + }); + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + + VectorType + getTypeForSubGroupTranspose(ArrayRef inVals, + ConversionPatternRewriter &rewriter) const { + auto elementTy = cast(inVals.front().getType()); + return elementTy.getWidth() <= 16 ? 
vec_ty(elementTy, 16) + : vec_ty(elementTy, 8); + } + + Value wrapInVector(Location loc, VectorType type, ArrayRef values, + ConversionPatternRewriter &rewriter) const { + assert(type.getShape()[0] == values.size() && "Size mismatch"); + Value res = rewriter.create(loc, type); + for (auto [index, val] : llvm::enumerate(values)) + res = insert_element(res, val, i32_val(index)); + return res; + } + + SmallVector + unwrapFromVector(Location loc, Value vec, + ConversionPatternRewriter &rewriter) const { + SmallVector res; + for (unsigned i = 0, n = cast(vec.getType()).getShape()[0]; + i < n; ++i) + res.push_back(extract_element(vec, i32_val(i))); + return res; + } + + SmallVector + performSubGroupTranspose(Location loc, ArrayRef inVals, + ConversionPatternRewriter &rewriter) const { + VectorType opType = getTypeForSubGroupTranspose(inVals, rewriter); + auto mod = rewriter.getInsertionPoint()->getParentOfType(); + unsigned vecWidth = opType.getShape()[0]; + + Value smemBase = LLVM::intel::getSharedMemoryBase( + loc, rewriter, targetInfo, &*rewriter.getInsertionPoint()); + Type ptrType = smemBase.getType(); + + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + int offset = threadsPerWarp; + Type offsetType = getTypeConverter()->getIndexType(); + Value subGroupId = getValueOrCreateCastToIndexLike( + rewriter, loc, offsetType, + rewriter.create( + loc, /*upper_bound=*/IntegerAttr{})); + Value subGroupLocalId = getValueOrCreateCastToIndexLike( + rewriter, loc, offsetType, + rewriter.create(loc, + /*upper_bound=*/IntegerAttr{})); + Value wiStride = + rewriter.create(loc, offsetType, threadsPerWarp); + Value sgStride = rewriter.create( + loc, offsetType, threadsPerWarp * threadsPerWarp); + Value subGroupOffset = mul(sgStride, subGroupId); + Type elementType = opType.getElementType(); + Value subGroupBasePtr = gep(ptrType, elementType, smemBase, + ValueRange{subGroupOffset}, /*inbounds=*/true); + Value base = subGroupBasePtr; + // Store in matrix, non-transposed + for (ArrayRef vals = inVals; !vals.empty(); + vals = vals.drop_front(vecWidth)) { + ArrayRef curr = vals.take_front(vecWidth); + Value vec = wrapInVector(loc, opType, curr, rewriter); + rewriter.create(loc, base, vec); + base = gep(base.getType(), opType, base, ArrayRef{offset}, + /*inbounds=*/true); + } + + // Load from matrix, trasposed. + Value workItemOffset = mul(wiStride, subGroupLocalId); + Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr, + ValueRange{workItemOffset}, /*inbounds=*/true); + Value transposedVec = + load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr); + + return unwrapFromVector(loc, transposedVec, rewriter); + } + LogicalResult transferWithinBlockGroup(ConvertLayoutOp op, const LinearLayout &srcLayout, const LinearLayout &dstLayout, OpAdaptor adaptor, @@ -548,6 +749,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO(jlebar): Implement me. return failure(); } + +private: + const triton::intel::TargetInfo &targetInfo; }; } // namespace @@ -560,7 +764,7 @@ void mlir::triton::intel::populateConvertLayoutOpToLLVMPatterns( // Eventually the LL conversion will subsume all of the others and be the only // one left. 
patterns.add( - typeConverter, benefit.getBenefit() + 1); + typeConverter, targetInfo, benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit); } From 986d273f43bb5c4b623e56e66b6d4ffed5703f5a Mon Sep 17 00:00:00 2001 From: victor-eds Date: Fri, 18 Oct 2024 14:00:35 +0100 Subject: [PATCH 2/7] Check transposition feasibility before performing it --- .../ConvertLayoutOpToLLVM.cpp | 132 ++++++++++-------- 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 44e85b869f..65694cb9e0 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -434,7 +434,7 @@ struct ConvertLayoutOpConversion struct ConvertLayoutOpUsingLinearLayoutsConversion : public ConvertOpToLLVMPattern { - constexpr static unsigned maxSubGroupTransposeWidth = 64; + constexpr static unsigned minSubGroupTransposeWidth = 8; // Set benefit to 2 so that this pattern applies before other convert-layout // conversions. TODO(jlebar): Eventually we want this to be the only pattern. @@ -557,14 +557,44 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return success(); } + bool isSupportedSubGroupTranspose(ConvertLayoutOp op, + OpAdaptor adaptor) const { + auto srcType = cast(adaptor.getSrc().getType()); + ArrayRef body = srcType.getBody(); + auto mod = op->getParentOfType(); + // Only supporting sub_group_size^2 transpositions for now. + if (body.size() != + mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)) + return false; + return TypeSwitch(body.front()) + .Case([this](FloatType floatTy) { + // Support via bitcasting to integer type. + return isValidTypeForSubGroupTranspose( + IntegerType::get(floatTy.getContext(), floatTy.getWidth())); + }) + .Case([this](IntegerType intTy) { + // Support via extending to supported type. + return isValidTypeForSubGroupTranspose(intTy) || + intTy.getWidth() < minSubGroupTransposeWidth; + }) + .Case([](LLVM::LLVMPointerType) { + // Support via ptrtoint + return true; + }) + .Default([](auto) { return false; }); + } + LogicalResult transferWithinLane(ConvertLayoutOp op, const LinearLayout &srcLayout, const LinearLayout &dstLayout, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (isSubGroupTranspose(srcLayout, dstLayout)) - return performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, - rewriter); + // If the operation is a supported sub-group transposition, perform via SLM. 
+ if (isSubGroupTranspose(srcLayout, dstLayout) && + isSupportedSubGroupTranspose(op, adaptor)) { + performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter); + return success(); + } return failure(); } @@ -577,10 +607,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion .Default([](auto) { return false; }); } - LogicalResult - performSubGroupTranspose(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + void performSubGroupTranspose(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(isSubGroupTranspose(srcLayout, dstLayout) && + "Expecting sub-group transpose"); + assert(isSupportedSubGroupTranspose(op, adaptor) && + "Expecting supported sub-group transpose"); + Location loc = op.getLoc(); SmallVector inVals = @@ -588,50 +624,39 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO: Support multiples of sub_group_size auto mod = op->getParentOfType(); - if (inVals.size() != - mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)) - return failure(); auto srcTy = cast(op.getSrc().getType()); - Type origElemTy = srcTy.getElementType(); - - LogicalResult conversionRes = - TypeSwitch(origElemTy) - .Case([&](FloatType floatTy) { - // TODO: Support FP4. - Type dstType = int_ty(floatTy.getWidth()); - if (!isValidTypeForSubGroupTranspose(dstType)) - return failure(); - llvm::transform( - inVals, std::begin(inVals), - [&](Value val) -> Value { return bitcast(val, dstType); }); - return success(); - }) - .Case([&](IntegerType intTy) { - if (isValidTypeForSubGroupTranspose(intTy)) - return success(); - if (intTy.getWidth() > maxSubGroupTransposeWidth) - return failure(); - // intTy.getWidth() < minSubGroupTransposeWidth - Type dstType = i8_ty; - llvm::transform( - inVals, std::begin(inVals), - [&](Value val) -> Value { return zext(dstType, val); }); - return success(); - }) - .Case([&](triton::PointerType) { - Type dstType = i64_ty; - assert(isValidTypeForSubGroupTranspose(dstType) && - "i64 type should be supported"); - llvm::transform( - inVals, std::begin(inVals), - [&](Value val) -> Value { return ptrtoint(dstType, val); }); - return success(); - }) - .Default([&](auto) { return failure(); }); - - if (failed(conversionRes)) - return conversionRes; + Type origElemTy = inVals.front().getType(); + + TypeSwitch(origElemTy) + .Case([&](FloatType floatTy) { + // TODO: Support FP4. 
+ Type dstType = int_ty(floatTy.getWidth()); + assert(isValidTypeForSubGroupTranspose(dstType) && + "Expecting valid type"); + llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value { + return bitcast(val, dstType); + }); + }) + .Case([&](IntegerType intTy) { + if (isValidTypeForSubGroupTranspose(intTy)) + return; + assert(intTy.getWidth() < minSubGroupTransposeWidth && + "Expecting type to extend to i8"); + Type dstType = i8_ty; + llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value { + return zext(dstType, val); + }); + }) + .Case([&](LLVM::LLVMPointerType) { + Type dstType = i64_ty; + assert(isValidTypeForSubGroupTranspose(dstType) && + "i64 type should be supported"); + llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value { + return ptrtoint(dstType, val); + }); + }) + .Default([](auto) { llvm_unreachable("Unsupported type"); }); SmallVector outVals = performSubGroupTranspose(loc, inVals, rewriter); @@ -650,18 +675,15 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion outVals, std::begin(outVals), [&](Value val) -> Value { return trunc(origElemTy, val); }); }) - .Case([&](triton::PointerType ptrTy) { - Type llvmPtrTy = getTypeConverter()->convertType(ptrTy); - assert(llvmPtrTy && "Type conversion failed"); + .Case([&](LLVM::LLVMPointerType ptrTy) { llvm::transform( outVals, std::begin(outVals), - [&](Value val) -> Value { return inttoptr(llvmPtrTy, val); }); + [&](Value val) -> Value { return inttoptr(ptrTy, val); }); }); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); rewriter.replaceOp(op, result); - return success(); } VectorType From fe480efb4cf82deba06e6c9f1ab76620f6eb89ba Mon Sep 17 00:00:00 2001 From: victor-eds Date: Fri, 18 Oct 2024 16:28:39 +0100 Subject: [PATCH 3/7] Make safer and simplify tests --- .../Conversion/intel/sub-group-transpose.mlir | 237 +++++------------- .../ConvertLayoutOpToLLVM.cpp | 40 +-- 2 files changed, 86 insertions(+), 191 deletions(-) diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index a666bee600..f3f2933671 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s // Basic 16x16 transpose test @@ -9,18 +9,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_f16( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f16 to i16 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: 
%[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to f16 @@ -31,18 +33,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_bf16( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_bf16(%arg0: tensor<16x16xbf16, #blocked>) -> tensor<16x16xbf16, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 // CHECK-COUNT-16: llvm.bitcast %{{.*}} : bf16 to i16 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_us16PU3AS3tDv16_t(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<16xi16>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i16 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi16> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16 @@ -53,21 +57,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_f32( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = 
llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 @@ -78,17 +84,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_i8( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i8(%arg0: tensor<16x16xi8, #blocked>) -> tensor<16x16xi8, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.call spir_funccc @_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () - 
// CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi8, #blocked> -> tensor<16x16xi8, #blocked1> @@ -98,20 +106,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_i64( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i64(%arg0: tensor<16x16xi64, #blocked>) -> tensor<16x16xi64, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]], + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_68]], + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xi64, #blocked> -> tensor<16x16xi64, #blocked1> @@ -121,21 +131,23 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_ptr(%arg0: tensor<16x16x!tt.ptr, #blocked>) -> tensor<16x16x!tt.ptr, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 // CHECK-COUNT-16: llvm.ptrtoint %{{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: 
%[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]], + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_68]], + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi64> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul8PU3AS3mDv8_m(%[[VAL_60]] // CHECK-SAME: (!llvm.ptr<3>, vector<8xi64>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi64> // CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1> @@ -146,18 +158,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test_i1( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test_i1(%arg0: tensor<16x16xi1, #blocked>) -> tensor<16x16xi1, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 // CHECK-COUNT-16: llvm.zext %{{.*}} : i1 to i8 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.call spir_funccc 
@_Z32intel_sub_group_block_write_uc16PU3AS3hDv16_h(%[[VAL_59]] // CHECK-SAME: (!llvm.ptr<3>, vector<16xi8>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi8> // CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1 @@ -177,24 +191,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1> tt.return %0 : tensor<32x16xf32, #blocked1> } @@ -211,24 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr 
inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1> tt.return %0 : tensor<16x32xf32, #blocked1> } @@ -245,24 +225,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64xf32, #blocked>) -> tensor<64x64xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> tt.return %0 : tensor<64x64xf32, #blocked1> } @@ -279,24 +242,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64x1xf32, #blocked>) -> tensor<64x64x1xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = 
llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<64x64x1xf32, #blocked> -> tensor<64x64x1xf32, #blocked1> tt.return %0 : tensor<64x64x1xf32, #blocked1> } @@ -312,24 +258,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>>) -> tensor<64x64xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf32, #triton_gpu.slice<{dim = 2, parent = #blocked}>> -> tensor<64x64xf32, #blocked1> tt.return %0 : tensor<64x64xf32, #blocked1> } @@ -346,24 +275,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<16x16x1xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) 
: i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<16x16x1xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<16x16x1xf32, #blocked1> tt.return %0 : tensor<16x16x1xf32, #blocked1> } @@ -380,24 +292,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3> tt.func @test(%arg0: tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = #triton_gpu.slice<{dim = 4, parent = #blocked}>}>>) -> tensor<64x16x4xf32, #blocked1> { - // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(256 : i64) : i64 - // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : f32 to i32 - // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() - // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 - // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() - // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 - // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_55]], %[[VAL_19]] : i64 - // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[VAL_1]][%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_68:.*]] = llvm.getelementptr inbounds %[[VAL_59]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> - // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_68]], - // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () - // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_57]], %[[VAL_20]] : i64 - // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 - // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> - // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i32 to f32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j( %0 = triton_gpu.convert_layout %arg0 : tensor<64x16x4xf32, #triton_gpu.slice<{dim = 2, parent = 
#triton_gpu.slice<{dim = 4, parent = #blocked}>}>> -> tensor<64x16x4xf32, #blocked1> tt.return %0 : tensor<64x16x4xf32, #blocked1> } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 65694cb9e0..466cc26d78 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -446,21 +446,23 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion bool isSubGroupTranspose(const LinearLayout &srcLayout, const LinearLayout &dstLayout) const { - if (srcLayout.getInDimNames().empty()) - return false; - MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext(); StringAttr kRegister = str_attr("register"); StringAttr kLane = str_attr("lane"); - if (!srcLayout.hasInDim(kLane) || !dstLayout.hasInDim(kLane) || - !srcLayout.hasInDim(kRegister) || !dstLayout.hasInDim(kRegister)) - return false; - - auto srcBases = srcLayout.getBases(); - auto dstBases = dstLayout.getBases(); + StringAttr kWarp = str_attr("warp"); + StringAttr kBlock = str_attr("block"); - return srcBases[kRegister] == dstBases[kLane] && - srcBases[kLane] == dstBases[kRegister]; + LinearLayout comp = srcLayout.invertAndCompose(dstLayout); + std::optional conversion = comp.divideRight( + LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) * + LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock)); + assert(conversion && "Expecting valid conversion"); + LinearLayout id = + LinearLayout::identity1D(conversion->getInDimSize(kRegister), kRegister, + kRegister) * + LinearLayout::identity1D(conversion->getInDimSize(kLane), kLane, kLane); + // Composing the transposition with itself should give us the identity. + return id == conversion->compose(*conversion); } LogicalResult @@ -561,10 +563,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion OpAdaptor adaptor) const { auto srcType = cast(adaptor.getSrc().getType()); ArrayRef body = srcType.getBody(); + // TODO: Support more configurations. auto mod = op->getParentOfType(); - // Only supporting sub_group_size^2 transpositions for now. 
- if (body.size() != - mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod)) + int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (body.size() != threadsPerWarp) return false; return TypeSwitch(body.front()) .Case([this](FloatType floatTy) { @@ -581,7 +583,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Support via ptrtoint return true; }) - .Default([](auto) { return false; }); + .Default(false); } LogicalResult transferWithinLane(ConvertLayoutOp op, @@ -604,7 +606,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion unsigned width = intTy.getWidth(); return width == 8 || width == 16 || width == 32 || width == 64; }) - .Default([](auto) { return false; }); + .Default(false); } void performSubGroupTranspose(ConvertLayoutOp op, @@ -622,9 +624,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion SmallVector inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - // TODO: Support multiples of sub_group_size - auto mod = op->getParentOfType(); - auto srcTy = cast(op.getSrc().getType()); Type origElemTy = inVals.front().getType(); @@ -679,7 +678,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion llvm::transform( outVals, std::begin(outVals), [&](Value val) -> Value { return inttoptr(ptrTy, val); }); - }); + }) + .Default([](auto) { llvm_unreachable("Unsupported type"); }); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); From 5b59ae404b26d07c6a29cd13e27359583a070c26 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Fri, 18 Oct 2024 16:31:49 +0100 Subject: [PATCH 4/7] Add comment back --- .../intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 466cc26d78..635bfc590b 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -597,6 +597,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter); return success(); } + // TODO(jlebar): Implement me. return failure(); } From f6649b9a2b895fbe039caceb4878ff94e5b19ff6 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Fri, 18 Oct 2024 17:38:53 +0100 Subject: [PATCH 5/7] Fix layout check --- .../ConvertLayoutOpToLLVM.cpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 635bfc590b..e3699680ce 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -457,12 +457,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) * LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock)); assert(conversion && "Expecting valid conversion"); - LinearLayout id = - LinearLayout::identity1D(conversion->getInDimSize(kRegister), kRegister, - kRegister) * - LinearLayout::identity1D(conversion->getInDimSize(kLane), kLane, kLane); - // Composing the transposition with itself should give us the identity. 
- return id == conversion->compose(*conversion); + // Expected conversion is: + // - register=1 -> (0, 1) + // register=2 -> (0, 2) + // register=4 -> (0, 4) + // register=8 -> (0, 8) + // - lane=1 -> (1, 0) + // lane=2 -> (2, 0) + // lane=4 -> (4, 0) + // lane=8 -> (8, 0) + // where out dims are: [register (size 16), lane (size 16)] + std::array>>, 2> + bases{{{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}}; + std::array outDimNames{kRegister, kLane}; + return conversion == LinearLayout(bases, outDimNames); } LogicalResult From 602175b66276542fa16cb1766b127775ef6fadd0 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Mon, 21 Oct 2024 15:39:56 +0100 Subject: [PATCH 6/7] [TritonIntelGPUToLLVM] Extend sub-group transposition support Extend sub-group transposition support allowing `N*sub_group_size` elements per thread. As per block load semantics (matrix of `sub_group_size` columns), we need `N` vector loads to load the transposed matrix from local memory. Signed-off-by: victor-eds --- .../Conversion/intel/sub-group-transpose.mlir | 129 ++++++++++++++++++ .../ConvertLayoutOpToLLVM.cpp | 52 ++++--- 2 files changed, 164 insertions(+), 17 deletions(-) diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index f3f2933671..71dd96f9ca 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -297,3 +297,132 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 tt.return %0 : tensor<64x16x4xf32, #blocked1> } } + +// ----- + +// Test transposition with 32 elements per work-item. + +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: 
%[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32> + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #blocked> -> tensor<32x16xf32, #blocked1> + tt.return %0 : tensor<32x16xf32, #blocked1> + } +} + +// ----- + +// Test transposition with 32 elements per work-item with a different layout. + +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> { + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc 
@_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds %[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32> + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<16x32xf32, #blocked> -> tensor<16x32xf32, #blocked1> + tt.return %0 : tensor<16x32xf32, #blocked1> + } +} + +// ----- + +// Test transposition with 32 elements per work-item and two warps in each dimension. + +#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test( + // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) + tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> { + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : f32 to i32 + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %[[VAL_1]]{{\[}}%[[ZERO]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 + // CHECK: %[[VAL_54:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK: %[[VAL_55:.*]] = llvm.zext %[[VAL_54]] : i32 to i64 + // CHECK: %[[VAL_56:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK: %[[VAL_57:.*]] = llvm.zext %[[VAL_56]] : i32 to i64 + // CHECK-DAG: %[[VAL_19:.*]] = llvm.mlir.constant(512 : i64) : i64 + // CHECK-DAG: %[[VAL_20:.*]] = llvm.mlir.constant(16 : i64) : i64 + // CHECK: %[[VAL_58:.*]] = llvm.mul %[[VAL_19]], %[[VAL_55]] : i64 + // CHECK: %[[VAL_59:.*]] = llvm.getelementptr inbounds %[[BASE]]{{\[}}%[[VAL_58]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_59]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_60:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_60]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_61:.*]] = llvm.getelementptr inbounds %[[VAL_60]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_61]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_62:.*]] = llvm.getelementptr inbounds %[[VAL_61]]{{\[}}16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<8xi32> + // CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS3jDv8_j(%[[VAL_62]] + // CHECK-SAME: (!llvm.ptr<3>, vector<8xi32>) -> () + // CHECK: %[[VAL_76:.*]] = llvm.mul %[[VAL_20]], %[[VAL_57]] : i64 + // CHECK: %[[VAL_77:.*]] = llvm.getelementptr inbounds %[[VAL_59]]{{\[}}%[[VAL_76]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i32 + // CHECK: llvm.load %[[VAL_77]] : !llvm.ptr<3> -> vector<16xi32> + // CHECK: %[[VAL_78:.*]] = llvm.getelementptr inbounds 
%[[VAL_77]][16] : (!llvm.ptr<3>) -> !llvm.ptr<3>, vector<16xi32> + // CHECK-COUNT-32: llvm.bitcast %{{.*}} : i32 to f32 + %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1> + tt.return %0 : tensor<32x64xf32, #blocked1> + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index e3699680ce..895ec5c27e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -462,13 +462,26 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // register=2 -> (0, 2) // register=4 -> (0, 4) // register=8 -> (0, 8) + // register=N -> (N, 0) + // ... // - lane=1 -> (1, 0) // lane=2 -> (2, 0) // lane=4 -> (4, 0) // lane=8 -> (8, 0) - // where out dims are: [register (size 16), lane (size 16)] + // where out dims are: [register (size 2*N), lane (size 16)] + std::vector> registerBases{ + {0, 1}, {0, 2}, {0, 4}, {0, 8}}; + { + // Populate register bases for N > 8. + std::vector base(2); + for (int32_t i = 16, n = conversion->getInDimSize(kRegister); i < n; + i *= 2) { + base.front() = i; + registerBases.push_back(base); + } + } std::array>>, 2> - bases{{{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}}}, + bases{{{kRegister, std::move(registerBases)}, {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}}; std::array outDimNames{kRegister, kLane}; return conversion == LinearLayout(bases, outDimNames); @@ -572,11 +585,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion OpAdaptor adaptor) const { auto srcType = cast(adaptor.getSrc().getType()); ArrayRef body = srcType.getBody(); - // TODO: Support more configurations. - auto mod = op->getParentOfType(); - int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - if (body.size() != threadsPerWarp) - return false; return TypeSwitch(body.front()) .Case([this](FloatType floatTy) { // Support via bitcasting to integer type. @@ -714,12 +722,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } SmallVector - unwrapFromVector(Location loc, Value vec, - ConversionPatternRewriter &rewriter) const { + unwrapFromVectors(Location loc, ArrayRef vecs, + ConversionPatternRewriter &rewriter) const { SmallVector res; - for (unsigned i = 0, n = cast(vec.getType()).getShape()[0]; - i < n; ++i) - res.push_back(extract_element(vec, i32_val(i))); + for (Value vec : vecs) { + for (unsigned i = 0, n = cast(vec.getType()).getShape()[0]; + i < n; ++i) + res.push_back(extract_element(vec, i32_val(i))); + } return res; } @@ -734,6 +744,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion loc, rewriter, targetInfo, &*rewriter.getInsertionPoint()); Type ptrType = smemBase.getType(); + int numElements = inVals.size(); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); int offset = threadsPerWarp; Type offsetType = getTypeConverter()->getIndexType(); @@ -748,7 +759,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion Value wiStride = rewriter.create(loc, offsetType, threadsPerWarp); Value sgStride = rewriter.create( - loc, offsetType, threadsPerWarp * threadsPerWarp); + loc, offsetType, threadsPerWarp * numElements); Value subGroupOffset = mul(sgStride, subGroupId); Type elementType = opType.getElementType(); Value subGroupBasePtr = gep(ptrType, elementType, smemBase, @@ -765,13 +776,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } // Load from matrix, trasposed. 
+ // As per SIMD block semantics, we have stored the elements in a matrix of + // `Nxsub_group_size` size, so we need to load back in blocks of + // `sub_group_size` (`N/sub_group_size` loads). Value workItemOffset = mul(wiStride, subGroupLocalId); Value workItemBasePtr = gep(ptrType, elementType, subGroupBasePtr, ValueRange{workItemOffset}, /*inbounds=*/true); - Value transposedVec = - load(vec_ty(opType.getElementType(), inVals.size()), workItemBasePtr); - - return unwrapFromVector(loc, transposedVec, rewriter); + SmallVector transposedVecs; + Type loadTy = vec_ty(opType.getElementType(), threadsPerWarp); + for (std::size_t i = 0, n = inVals.size(); i < n; i += threadsPerWarp) { + transposedVecs.push_back(load(loadTy, workItemBasePtr)); + workItemBasePtr = gep(ptrType, loadTy, workItemBasePtr, + ArrayRef{offset}, /*inbounds=*/true); + } + return unwrapFromVectors(loc, transposedVecs, rewriter); } LogicalResult From 79410a5229f00914c9d00876d07eaed4d9bb523f Mon Sep 17 00:00:00 2001 From: victor-eds Date: Thu, 24 Oct 2024 13:06:53 +0100 Subject: [PATCH 7/7] Fix rebase issues --- .../Conversion/intel/sub-group-transpose.mlir | 6 +- .../ConvertLayoutOpToLLVM.cpp | 69 ++++++++++--------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/test/Conversion/intel/sub-group-transpose.mlir b/test/Conversion/intel/sub-group-transpose.mlir index f8ecd9193b..8b2c5bd6aa 100644 --- a/test/Conversion/intel/sub-group-transpose.mlir +++ b/test/Conversion/intel/sub-group-transpose.mlir @@ -305,7 +305,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 16 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) tt.func @test(%arg0: tensor<32x16xf32, #blocked>) -> tensor<32x16xf32, #blocked1> { @@ -348,7 +348,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) tt.func @test(%arg0: tensor<16x32xf32, #blocked>) -> tensor<16x32xf32, #blocked1> { @@ -391,7 +391,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}> -module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 
: i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} { +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { // CHECK-LABEL: llvm.func spir_kernelcc @test( // CHECK-SAME: , %[[VAL_1:.*]]: !llvm.ptr<3>) tt.func @test(%arg0: tensor<32x64xf32, #blocked>) -> tensor<32x64xf32, #blocked1> { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ac9051b42f..aaec293bac 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -461,30 +461,49 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion assert(conversion && "Expecting valid conversion"); // Expected conversion is: // - register=1 -> (0, 1) - // register=2 -> (0, 2) - // register=4 -> (0, 4) - // register=8 -> (0, 8) - // register=N -> (N, 0) - // ... - // - lane=1 -> (1, 0) - // lane=2 -> (2, 0) - // lane=4 -> (4, 0) - // lane=8 -> (8, 0) - // where out dims are: [register (size 2*N), lane (size 16)] - std::vector> registerBases{ - {0, 1}, {0, 2}, {0, 4}, {0, 8}}; + // ... + // - register=2**i -> (0, 2**i) + // ... + // - register=M -> (0, 2**M) + // ... + // - register=2**k -> (2**k, 0) + // ... + // - register=N -> (2**N, 0) + // - lane=1 -> (0, 1) + // ... + // - lane=2**j -> (2**j, 0) + // ... + // lane=2**M -> (2**M, 0) + // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] + // + // With N >= M. + const auto buildBasis = [&](int32_t size, std::size_t index) { + std::vector> basis; + std::vector curr(2); + for (int32_t i = 1; i < size; i *= 2) { + curr[index] = i; + basis.push_back(curr); + } + return basis; + }; + constexpr std::size_t laneIndex = 0; + constexpr std::size_t registerIndex = 1; + int32_t laneSize = conversion->getInDimSize(kLane); + std::vector> registerBases = + buildBasis(laneSize, registerIndex); { - // Populate register bases for N > 8. + // Populate register bases for N > M. std::vector base(2); - for (int32_t i = 16, n = conversion->getInDimSize(kRegister); i < n; - i *= 2) { - base.front() = i; + for (int32_t i = laneSize, + registerSize = conversion->getInDimSize(kRegister); + i < registerSize; i *= 2) { + base[laneIndex] = i; registerBases.push_back(base); } } std::array>>, 2> bases{{{kRegister, std::move(registerBases)}, - {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}}; + {kLane, buildBasis(laneSize, laneIndex)}}}; std::array outDimNames{kRegister, kLane}; return conversion == LinearLayout(bases, outDimNames); } @@ -853,18 +872,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion llvm::transform( outVals, std::begin(outVals), [&](Value val) -> Value { return inttoptr(ptrTy, val); }); - })As a follow up to #2266, extend work in #2531 to detect more complex broadcast shuffles. - -Cases with more than 1 warp in the "sliced" dimension are problematic here, e.g.: - -```mlir -#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, 16, 1, 1, 1], warpsPerCTA = [1, 1, 16, 1, 1, 1, 1], order = [3, 4, 5, 6, 0, 1, 2]}> -#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1, 1, 16, 1], threadsPerWarp = [16, 1, 1, 1, 1], warpsPerCTA = [1, 1, 16, 1, 1], order = [3, 4, 0, 1, 2]}> -// ... 
-triton_gpu.convert_layout %arg : tensor<16x1x16x16x1xf32, #triton_gpu.slice<{dim = 4, parent = #triton_gpu.slice<{dim = 6, parent = #blocked}>}>> -> tensor<16x1x16x16x1xf32, #blocked1> -``` - -Is lowered to a shufle via + }) .Default([](auto) { llvm_unreachable("Unsupported type"); }); Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, @@ -967,9 +975,6 @@ Is lowered to a shufle via // TODO(jlebar): Implement me. return failure(); } - -private: - const triton::intel::TargetInfo &targetInfo; }; } // namespace
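For reference, the register/lane swap that `isSubGroupTranspose` checks for can be written out as a small standalone sketch. This is plain C++ and does not use Triton's `LinearLayout` class; the helper names (`expectedRegisterBases`, `expectedLaneBases`) and the concrete sizes (`laneSize = 16`, `registerSize = 32`, matching the 32x16 f32 test) are illustrative assumptions, not part of the patch. Each basis entry is an `(register, lane)` output pair: the low register bases move onto the lane dimension, the lane bases move onto the register dimension, and any register bases beyond `sub_group_size` stay in place.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// One entry per power-of-two input coordinate; each entry is a
// {registerOut, laneOut} pair, matching the "(register, lane)" pairs in the
// expected-conversion comment.
using Bases = std::vector<std::pair<int32_t, int32_t>>;

// Expected "register" bases: the low log2(laneSize) bases land on the lane
// output (the transposed sub_group_size x sub_group_size tile); any remaining
// bases stay on the register output (extra rows owned by the same work-item).
static Bases expectedRegisterBases(int32_t registerSize, int32_t laneSize) {
  Bases bases;
  for (int32_t i = 1; i < laneSize; i *= 2)
    bases.push_back({0, i});
  for (int32_t i = laneSize; i < registerSize; i *= 2)
    bases.push_back({i, 0});
  return bases;
}

// Expected "lane" bases: every lane basis lands on the register output.
static Bases expectedLaneBases(int32_t laneSize) {
  Bases bases;
  for (int32_t i = 1; i < laneSize; i *= 2)
    bases.push_back({i, 0});
  return bases;
}

int main() {
  // Illustrative sizes: sub_group_size = 16 with 32 elements per work-item,
  // i.e. the 32x16 f32 test above (N = 2 * sub_group_size).
  constexpr int32_t laneSize = 16;
  constexpr int32_t registerSize = 32;

  Bases reg = expectedRegisterBases(registerSize, laneSize);
  Bases lane = expectedLaneBases(laneSize);

  // register=1..8 swap onto the lane dimension; register=16 stays put.
  assert((reg[0] == std::pair<int32_t, int32_t>{0, 1}));
  assert((reg[3] == std::pair<int32_t, int32_t>{0, 8}));
  assert((reg[4] == std::pair<int32_t, int32_t>{16, 0}));
  // lane=1..8 all move onto the register dimension.
  assert((lane[0] == std::pair<int32_t, int32_t>{1, 0}));
  assert((lane[3] == std::pair<int32_t, int32_t>{8, 0}));
  return 0;
}
```

When `registerSize == laneSize` this degenerates to the single `sub_group_size^2` tile handled by the original commit; larger `registerSize` values correspond to the `N*sub_group_size` elements-per-work-item case added later in the series.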