diff --git a/test/Conversion/intel/sub-group-shuffle.mlir b/test/Conversion/intel/sub-group-shuffle.mlir index 1c9de97f59..1e9d32a8c7 100644 --- a/test/Conversion/intel/sub-group-shuffle.mlir +++ b/test/Conversion/intel/sub-group-shuffle.mlir @@ -257,3 +257,106 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return %0 : tensor<32xf32, #sliced1> } } + +// ----- + +// Case of more than one element per thread in the non-sliced dimension. + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> +#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}> +#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register( + // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>, + // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)> + // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)> + // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_5]]) + // CHECK: %[[VAL_8:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_8]]) + // CHECK: llvm.mlir.constant(true) : i1 + // CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_11]]) + // CHECK: %[[VAL_14:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_14]]) + // CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_17]]) + // CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(5 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_20]]) + // CHECK: %[[VAL_23:.*]] = llvm.mlir.constant(6 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_23]]) + // CHECK: %[[VAL_26:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_26]]) + // CHECK: %[[VAL_29:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_29]]) + // CHECK: %[[VAL_32:.*]] = llvm.mlir.constant(9 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_32]]) + // CHECK: %[[VAL_35:.*]] = llvm.mlir.constant(10 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_35]]) + // CHECK: %[[VAL_38:.*]] = llvm.mlir.constant(11 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_38]]) + // CHECK: %[[VAL_41:.*]] = llvm.mlir.constant(12 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_41]]) + // CHECK: %[[VAL_44:.*]] = llvm.mlir.constant(13 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_44]]) + // CHECK: %[[VAL_47:.*]] = llvm.mlir.constant(14 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_47]]) + // CHECK: %[[VAL_50:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_2]], %[[VAL_50]]) + // CHECK: %[[VAL_53:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_53]]) + // CHECK: %[[VAL_56:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_56]]) + // CHECK: %[[VAL_59:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_59]]) + // CHECK: %[[VAL_62:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_62]]) + // CHECK: %[[VAL_65:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_65]]) + // CHECK: %[[VAL_68:.*]] = llvm.mlir.constant(5 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_68]]) + // CHECK: %[[VAL_71:.*]] = llvm.mlir.constant(6 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_71]]) + // CHECK: %[[VAL_74:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_74]]) + // CHECK: %[[VAL_77:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_77]]) + // CHECK: %[[VAL_80:.*]] = llvm.mlir.constant(9 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_80]]) + // CHECK: %[[VAL_83:.*]] = llvm.mlir.constant(10 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_83]]) + // CHECK: %[[VAL_86:.*]] = llvm.mlir.constant(11 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_86]]) + // CHECK: %[[VAL_89:.*]] = llvm.mlir.constant(12 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_89]]) + // CHECK: %[[VAL_92:.*]] = llvm.mlir.constant(13 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_92]]) + // CHECK: %[[VAL_95:.*]] = llvm.mlir.constant(14 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_95]]) + // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffledj(%[[VAL_3]], %[[VAL_98]]) + tt.func @test_non_sliced_multi_register(%arg0: tensor<32xf64, #sliced>) -> tensor<32xf64, #sliced1> { + %0 = triton_gpu.convert_layout %arg0 : tensor<32xf64, #sliced> -> tensor<32xf64, #sliced1> + tt.return %0 : tensor<32xf64, #sliced1> + } +} + +// ----- + +// Case of more than one element per thread and 2 warps in the non-sliced dimension. + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 1], order = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}> +#sliced = #triton_gpu.slice<{dim = 1, parent = #blocked}> +#sliced1 = #triton_gpu.slice<{dim = 1, parent = #blocked1}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { + // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register_multi_warp + // CHECK-COUNT-64: llvm.call spir_funccc @_Z17sub_group_shuffleij + tt.func @test_non_sliced_multi_register_multi_warp(%arg0: tensor<128xi32, #sliced>) -> tensor<128xi32, #sliced1> { + %0 = triton_gpu.convert_layout %arg0 : tensor<128xi32, #sliced> -> tensor<128xi32, #sliced1> + tt.return %0 : tensor<128xi32, #sliced1> + } +} diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 18144d91fc..f5a4bf0bad 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -446,6 +446,62 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { } + // Return a vector such as: + // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ..., + // [registerSize / 2, 0]], + // i.e., mapping registers to lanes till laneSize and performing an ID + // conversion afterwards. + static std::vector> + buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + for (int32_t i = 1; i < laneSize; i *= 2) { + curr[1] = i; + bases.push_back(curr); + } + curr[1] = 0; + for (int32_t i = laneSize; i < registerSize; i *= 2) { + curr[0] = i; + bases.push_back(curr); + } + return bases; + } + + // Return a vector such as: + // [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ..., + // [registerSize / (2 * laneSize), 0]] + // i.e., mapping registers to lanes till laneSize and repeating the pattern + // afterwards. + static std::vector> + buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + for (int32_t i = 1; i < laneSize; i *= 2) { + curr[1] = i; + bases.push_back(curr); + } + curr[1] = 0; + for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) { + curr[0] = val; + bases.push_back(curr); + } + return bases; + } + + // Return a vector such as: + // [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]], + // i.e., mapping lanes to registers. + static std::vector> + buildSubGroupTransposeLaneBases(int32_t laneSize) { + std::vector> bases; + std::vector curr(2); + for (int32_t i = 1; i < laneSize; i *= 2) { + curr[0] = i; + bases.push_back(curr); + } + return bases; + } + bool isSubGroupTranspose(const LinearLayout &srcLayout, const LinearLayout &dstLayout) const { MLIRContext *ctx = srcLayout.getInDimNames().begin()->getContext(); @@ -476,35 +532,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))] // // With N >= M. - const auto buildBasis = [&](int32_t size, std::size_t index) { - std::vector> basis; - std::vector curr(2); - for (int32_t i = 1; i < size; i *= 2) { - curr[index] = i; - basis.push_back(curr); - } - return basis; - }; - constexpr std::size_t laneIndex = 0; - constexpr std::size_t registerIndex = 1; - int32_t laneSize = conversion->getInDimSize(kLane); - std::vector> registerBases = - buildBasis(laneSize, registerIndex); - { - // Populate register bases for N > M. - std::vector base(2); - for (int32_t i = laneSize, - registerSize = conversion->getInDimSize(kRegister); - i < registerSize; i *= 2) { - base[laneIndex] = i; - registerBases.push_back(base); - } - } - std::array>>, 2> - bases{{{kRegister, std::move(registerBases)}, - {kLane, buildBasis(laneSize, laneIndex)}}}; - std::array outDimNames{kRegister, kLane}; - return conversion == LinearLayout(bases, outDimNames); + int32_t registerInDimSize = conversion->getInDimSize(kRegister); + int32_t laneInDimSize = conversion->getInDimSize(kLane); + return conversion->getBases().lookup(kRegister) == + buildSubGroupTransposeRegisterBases(registerInDimSize, + laneInDimSize) && + conversion->getBases().lookup(kLane) == + buildSubGroupTransposeLaneBases(laneInDimSize); } LogicalResult @@ -619,32 +653,27 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Expected conversion is: // - register=1 -> (0, 1) // ... - // register=i -> (0, i) + // - register=2**i -> (0, 2**i) + // ... + // - register=M -> (0, 2**M) + // ... + // - register=2**k -> (2**(k-M), 0) // ... - // register=N -> (0, N) + // - register=2**N -> (2**(N-M), 0) // - lane=1 -> (0, 0) // ... - // lane=i -> (0, 0) + // - lane=2**j -> (0, 0) // ... - // lane=N -> (0, 0) - // where out dims are: [register (size 1), lane (size N)] - std::vector> registerBases; - { - constexpr std::size_t registerIndex = 1; - std::vector base(2); - for (int32_t i = 1, n = conversion->getInDimSize(kLane); i < n; i *= 2) { - base[registerIndex] = i; - registerBases.push_back(base); - } - } - - std::vector> laneBases( - conversion->getInDimSizeLog2(kLane), std::vector{0, 0}); - std::array>>, 2> - bases{{{kRegister, std::move(registerBases)}, - {kLane, std::move(laneBases)}}}; - std::array outDimNames{kRegister, kLane}; - return conversion == LinearLayout(bases, outDimNames); + // lane=2**M -> (0, 0) + // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))] + // + // With N >= M. + int32_t registerInDimSize = conversion->getInDimSize(kRegister); + int32_t laneOutDimSize = conversion->getOutDimSize(kLane); + return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) && + conversion->getBases().lookup(kRegister) == + buildSubGroupShuffleRegisterBases(registerInDimSize, + laneOutDimSize); } bool isSupportedSubGroupShuffle(ConvertLayoutOp, OpAdaptor) const { @@ -674,7 +703,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion SmallVector inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - assert(inVals.size() == 1 && "Expecting single element"); // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR // upstream level. We are not enabling support for all types here as that @@ -703,7 +731,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion }); SmallVector outVals = - performSubGroupShuffle(loc, inVals.front(), subGroupSize, rewriter); + performSubGroupShuffle(loc, inVals, subGroupSize, rewriter); // TODO: Drop 'BFloat16Type' and 'IntegerType' cases when supported at MLIR // upstream level. We are not enabling support for all types here as that @@ -734,16 +762,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } SmallVector - performSubGroupShuffle(Location loc, Value val, int32_t subGroupSize, + performSubGroupShuffle(Location loc, ArrayRef inVals, + int32_t subGroupSize, ConversionPatternRewriter &rewriter) const { SmallVector res; Value width = i32_val(subGroupSize); - for (int32_t i = 0; i < subGroupSize; ++i) - res.push_back( - rewriter - .create(loc, val, i32_val(i), width, - mlir::gpu::ShuffleMode::IDX) - .getShuffleResult()); + for (Value val : inVals) { + for (int32_t i = 0; i < subGroupSize; ++i) + res.push_back( + rewriter + .create(loc, val, i32_val(i), width, + mlir::gpu::ShuffleMode::IDX) + .getShuffleResult()); + } return res; }