From 70e7eb27c6c90c40cfeeb283d7d396aa72e9b378 Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Tue, 22 Oct 2024 13:12:16 +0100
Subject: [PATCH 1/3] [TritonIntelGPUToLLVM] Detect basic sub-group shuffle
 convert_layout cases

Detect basic shuffles and lower them to `gpu.shuffle` operations.
Basically, support cases in which we go from each work-item holding a
single tensor element to each work-item holding `sub_group_size` tensor
elements such that element `i` corresponds to the element originally
held by work-item `i` in the sub-group.

The upstream MLIR pass should handle all integer and floating point
types; drop the code handling type legalization for those types once it
does. Pointer types should still be handled in this project.

The code should be extended to support other kinds of shuffles. The
multi-warp case is not yet implemented.

Signed-off-by: victor-eds
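
For illustration, the kind of conversion being detected looks as
follows. The snippet is distilled from the 16-wide lit test added
below and is not an additional test; the trailing comments are
explanatory only:

  #blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
  #blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
  #sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>    // One element per work-item.
  #sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>  // 16 elements per work-item.
  %0 = triton_gpu.convert_layout %arg0 : tensor<16xf16, #sliced> -> tensor<16xf16, #sliced1>

This lowers to 16 `gpu.shuffle` idx operations, one per lane index,
which are in turn lowered to OpenCL `sub_group_shuffle` builtin calls,
as the CHECK lines in the test show.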
---
 test/Conversion/intel/sub-group-shuffle.mlir  | 259 ++++++++++++++++++
 .../ConvertLayoutOpToLLVM.cpp                 | 157 +++++++++++
 2 files changed, 416 insertions(+)
 create mode 100644 test/Conversion/intel/sub-group-shuffle.mlir

diff --git a/test/Conversion/intel/sub-group-shuffle.mlir b/test/Conversion/intel/sub-group-shuffle.mlir
new file mode 100644
index 0000000000..0e80822368
--- /dev/null
+++ b/test/Conversion/intel/sub-group-shuffle.mlir
@@ -0,0 +1,259 @@
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s
+
+// Basic 16x16 shuffle test
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>,
+  // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16)>
+  // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_4]])
+  // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_7]])
+  // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_10]])
+  // CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_13]])
+  // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_16]])
+  // CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_19]])
+  // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_22]])
+  // CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_25]])
+  // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_28]])
+  // CHECK: %[[VAL_31:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_31]])
+  // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_34]])
+  // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_37]])
+  // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_40]])
+  // CHECK: %[[VAL_43:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_43]])
+  // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_46]])
+  // CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_49]])
+  tt.func @test_f16(%arg0: tensor<16xf16, #sliced>) -> tensor<16xf16, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf16, #sliced> -> tensor<16xf16, #sliced1>
+    tt.return %0 : tensor<16xf16, #sliced1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_bf16(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>,
+  // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(bf16)>
+  // CHECK: %[[VAL_2:.*]] = llvm.bitcast %[[VAL_1]] : bf16 to i16
+  // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_4]])
+  // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_7]])
+  // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_10]])
+  // CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_13]])
+  // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_16]])
+  // CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_19]])
+  // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_22]])
+  // CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_25]])
+  // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_28]])
+  // CHECK: %[[VAL_31:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_31]])
+  // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_34]])
+  // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_37]])
+  // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_40]])
+  // CHECK: %[[VAL_43:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_43]])
+  // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_46]])
+  // CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflesj(%[[VAL_2]], %[[VAL_49]])
+  // CHECK-COUNT-16: llvm.bitcast %{{.*}} : i16 to bf16
+  tt.func @test_bf16(%arg0: tensor<16xbf16, #sliced>) -> tensor<16xbf16, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xbf16, #sliced> -> tensor<16xbf16, #sliced1>
+    tt.return %0 : tensor<16xbf16, #sliced1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_i1(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>,
+  // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i1)>
+  // CHECK: %[[VAL_2:.*]] = llvm.zext %[[VAL_1]] : i1 to i8
+  // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_4]])
+  // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_7]])
+  // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_10]])
+  // CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_13]])
+  // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_16]])
+  // CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_19]])
+  // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_22]])
+  // CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_25]])
+  // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_28]])
+  // CHECK: %[[VAL_31:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_31]])
+  // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_34]])
+  // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_37]])
+  // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_40]])
+  // CHECK: %[[VAL_43:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_43]])
+  // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_46]])
+  // CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[VAL_2]], %[[VAL_49]])
+  // CHECK-COUNT-16: llvm.trunc %{{.*}} : i8 to i1
+  tt.func @test_i1(%arg0: tensor<16xi1, #sliced>) -> tensor<16xi1, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16xi1, #sliced> -> tensor<16xi1, #sliced1>
+    tt.return %0 : tensor<16xi1, #sliced1>
+  }
+
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>,
+  // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(ptr<1>)>
+  // CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr<1> to i64
+  // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_4]])
+  // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_7]])
+  // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_10]])
+  // CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_13]])
+  // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_16]])
+  // CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_19]])
+  // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_22]])
+  // CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_25]])
+  // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_28]])
+  // CHECK: %[[VAL_31:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_31]])
+  // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_34]])
+  // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_37]])
+  // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_40]])
+  // CHECK: %[[VAL_43:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_43]])
+  // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_46]])
+  // CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[VAL_2]], %[[VAL_49]])
+  // CHECK-COUNT-16: llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+  tt.func @test_ptr(%arg0: tensor<16x!tt.ptr<f32>, #sliced>) -> tensor<16x!tt.ptr<f32>, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x!tt.ptr<f32>, #sliced> -> tensor<16x!tt.ptr<f32>, #sliced1>
+    tt.return %0 : tensor<16x!tt.ptr<f32>, #sliced1>
+  }
+}
+
+// -----
+
+// Sub-group size 32 variant.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [32, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [0, 1]}>
+#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
+#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>,
+  // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32)>
+  // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_4]])
+  // CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_7]])
+  // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_10]])
+  // CHECK: %[[VAL_13:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_13]])
+  // CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_16]])
+  // CHECK: %[[VAL_19:.*]] = llvm.mlir.constant(5 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_19]])
+  // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(6 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_22]])
+  // CHECK: %[[VAL_25:.*]] = llvm.mlir.constant(7 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_25]])
+  // CHECK: %[[VAL_28:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_28]])
+  // CHECK: %[[VAL_31:.*]] = llvm.mlir.constant(9 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_31]])
+  // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(10 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_34]])
+  // CHECK: %[[VAL_37:.*]] = llvm.mlir.constant(11 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_37]])
+  // CHECK: %[[VAL_40:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_40]])
+  // CHECK: %[[VAL_43:.*]] = llvm.mlir.constant(13 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_43]])
+  // CHECK: %[[VAL_46:.*]] = llvm.mlir.constant(14 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_46]])
+  // CHECK: %[[VAL_49:.*]] = llvm.mlir.constant(15 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_49]])
+  // CHECK: %[[VAL_52:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_52]])
+  // CHECK: %[[VAL_55:.*]] = llvm.mlir.constant(17 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_55]])
+  // CHECK: %[[VAL_58:.*]] = llvm.mlir.constant(18 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_58]])
+  // CHECK: %[[VAL_61:.*]] = llvm.mlir.constant(19 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_61]])
+  // CHECK: %[[VAL_64:.*]] = llvm.mlir.constant(20 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_64]])
+  // CHECK: %[[VAL_67:.*]] = llvm.mlir.constant(21 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_67]])
+  // CHECK: %[[VAL_70:.*]] = llvm.mlir.constant(22 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_70]])
+  // CHECK: %[[VAL_73:.*]] = llvm.mlir.constant(23 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_73]])
+  // CHECK: %[[VAL_76:.*]] = llvm.mlir.constant(24 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_76]])
+  // CHECK: %[[VAL_79:.*]] = llvm.mlir.constant(25 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_79]])
+  // CHECK: %[[VAL_82:.*]] = llvm.mlir.constant(26 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_82]])
+  // CHECK: %[[VAL_85:.*]] = llvm.mlir.constant(27 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_85]])
+  // CHECK: %[[VAL_88:.*]] = llvm.mlir.constant(28 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_88]])
+  // CHECK: %[[VAL_91:.*]] = llvm.mlir.constant(29 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_91]])
+  // CHECK: %[[VAL_94:.*]] = llvm.mlir.constant(30 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_94]])
+  // CHECK: %[[VAL_97:.*]] = llvm.mlir.constant(31 : i32) : i32
+  // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_97]])
+  tt.func @test_f32(%arg0: tensor<32xf32, #sliced>) -> tensor<32xf32, #sliced1> {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32xf32, #sliced> -> tensor<32xf32, #sliced1>
+    tt.return %0 : tensor<32xf32, #sliced1>
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 44f33b13e2..7ad29a147d 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -2,6 +2,8 @@
 #include "TargetInfo.h"
 #include "Utility.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 #include "intel/include/Analysis/Utility.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
@@ -532,11 +534,166 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }
 
+  bool isSubGroupShuffle(const LinearLayout &srcLayout,
+                         const LinearLayout &dstLayout) const {
+    MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext();
+    StringAttr kRegister = str_attr("register");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+
+    LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
+    std::optional<LinearLayout> conversion = comp.divideRight(
+        LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
+        LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
+    assert(conversion && "Expecting valid conversion");
+    // TODO: Support more kinds of shuffles.
+    // Expected conversion is:
+    // - register=1 -> (0, 1)
+    //   ...
+    //   register=i -> (0, i)
+    //   ...
+    //   register=N -> (0, N)
+    // - lane=1 -> (0, 0)
+    //   ...
+    //   lane=i -> (0, 0)
+    //   ...
+    //   lane=N -> (0, 0)
+    // where out dims are: [register (size 1), lane (size N)]
+    std::vector<std::vector<int32_t>> registerBases;
+    {
+      constexpr std::size_t registerIndex = 1;
+      std::vector<int32_t> base(2);
+      for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) {
+        base[registerIndex] = i;
+        registerBases.push_back(base);
+      }
+    }
+
+    std::vector<std::vector<int32_t>> laneBases(
+        conversion->getInDimSizeLog2(kLane), std::vector<int32_t>{0, 0});
+    std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
+        bases{{{kRegister, std::move(registerBases)},
+               {kLane, std::move(laneBases)}}};
+    std::array outDimNames{kRegister, kLane};
+    return conversion == LinearLayout(bases, outDimNames);
+  }
+
+  bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const {
+    // TODO: Limit when sub-group shuffles get more complex.
+    // We do not need to limit by type here as `gpu.shuffle` conversion will
+    // fail for us.
+    return true;
+  }
+
+  void performSubGroupShuffle(ConvertLayoutOp op, const LinearLayout &srcLayout,
+                              const LinearLayout &dstLayout, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter) const {
+    assert(isSubGroupShuffle(srcLayout, dstLayout) &&
+           "Expecting sub-group shuffle");
+    assert(isSupportedSubGroupShuffle(op, adaptor) &&
+           "Expecting supported sub-group shuffle");
+
+    MLIRContext *ctx = op->getContext();
+    StringAttr kLane = str_attr("lane");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kBlock = str_attr("block");
+    LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
+    std::optional<LinearLayout> conversion = comp.divideRight(
+        LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
+        LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
+    assert(conversion && "Expecting valid layout");
+    int32_t subGroupSize = conversion->getOutDimSize(kLane);
+
+    Location loc = op.getLoc();
+
+    SmallVector<Value> inVals =
+        unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    assert(inVals.size() == 1 && "Expecting single element");
+
+    // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
+    // upstream level. We are not enabling support for all types here as that
+    // should be done upstream.
+    Type origElemTy = inVals.front().getType();
+    TypeSwitch<Type>(origElemTy)
+        .Case([&](BFloat16Type) {
+          auto intTy = i16_ty;
+          llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
+            return bitcast(val, intTy);
+          });
+        })
+        .Case([&](IntegerType intTy) {
+          constexpr unsigned minWidth = 8;
+          if (intTy.getWidth() >= minWidth)
+            return;
+          auto dstTy = i8_ty;
+          llvm::transform(inVals, std::begin(inVals),
+                          [&](Value val) -> Value { return zext(dstTy, val); });
+        })
+        .Case([&](LLVM::LLVMPointerType) {
+          Type dstType = i64_ty;
+          llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
+            return ptrtoint(dstType, val);
+          });
+        });
+
+    SmallVector<Value> outVals =
+        performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter);
+
+    // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR
+    // upstream level. We are not enabling support for all types here as that
+    // should be done upstream.
+    TypeSwitch<Type>(origElemTy)
+        .Case([&](BFloat16Type) {
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return bitcast(val, origElemTy); });
+        })
+        .Case([&](IntegerType intTy) {
+          // Check whether conversion took place.
+          if (intTy == outVals.front().getType())
+            return;
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return trunc(origElemTy, val); });
+        })
+        .Case([&](LLVM::LLVMPointerType ptrTy) {
+          llvm::transform(
+              outVals, std::begin(outVals),
+              [&](Value val) -> Value { return inttoptr(ptrTy, val); });
+        });
+
+    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
+                                  op.getType());
+    rewriter.replaceOp(op, result);
+  }
+
+  SmallVector<Value>
+  performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize,
+                         ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> res;
+    Value width = i32_val(subGroupSize);
+    for (int32_t i = 0; i < subGroupSize; ++i)
+      res.push_back(
+          rewriter
+              .create<mlir::gpu::ShuffleOp>(loc, val, i32_val(i), width,
+                                            mlir::gpu::ShuffleMode::IDX)
+              .getShuffleResult());
+    return res;
+  }
+
   LogicalResult
   transferWithinLane(ConvertLayoutOp op, const LinearLayout &srcLayout,
                      const LinearLayout &dstLayout, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const {
+    // If the operation is a supported sub-group shuffle, perform via shuffle
+    // operations.
+    if (isSubGroupShuffle(srcLayout, dstLayout) &&
+        isSupportedSubGroupShuffle(op, adaptor)) {
+      performSubGroupShuffle(op, srcLayout, dstLayout, adaptor, rewriter);
+      return success();
+    }
     // TODO(jlebar): Implement me.
     return failure();
   }

From 67fe7b16405ecfa16894a27707d81e09be41a4f7 Mon Sep 17 00:00:00 2001
From: victor-eds
Date: Thu, 24 Oct 2024 11:07:24 +0100
Subject: [PATCH 2/3] Drop unneeded attrs

---
 test/Conversion/intel/sub-group-shuffle.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/Conversion/intel/sub-group-shuffle.mlir b/test/Conversion/intel/sub-group-shuffle.mlir
index 0e80822368..1c9de97f59 100644
--- a/test/Conversion/intel/sub-group-shuffle.mlir
+++ b/test/Conversion/intel/sub-group-shuffle.mlir
@@ -7,7 +7,7 @@
 #sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
 #sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
   // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>,
   // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16)>
@@ -184,7 +184,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 #sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}>
 #sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}>
 
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 32 : i32} {
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
   // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>,
   // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32)>

From b5baa69356d4cc2d577f5f0099ba05248e03753e Mon Sep 17 00:00:00 2001
From: Victor Perez
Date: Thu, 24 Oct 2024 12:58:42 +0200
Subject: [PATCH 3/3] Remove whitespace

---
 .../intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 1e3fb3219c..a03989d765 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -734,7 +734,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
               .getShuffleResult());
     return res;
   }
-  
+
   bool isSupportedSubGroupTranspose(ConvertLayoutOp op,
                                     OpAdaptor adaptor) const {
     auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());