diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2c646934c11c2..72aca2938e029 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -656,6 +656,48 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
   }];
 }
 
+def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> {
+  let summary = "AMDGPU permlane swap op";
+  let description = [{
+    High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations
+    on rows of lanes in a subgroup.
+
+    Supports arbitrary int/float/vector types, which will be repacked to i32 and
+    one or more `rocdl.permlane{16,32}.swap` ops during lowering.
+    Supported lane permutations:
+    - Swap the data between odd and even rows of 16 lanes
+    - Swap the data between the first 32 lanes and the last 32 lanes
+
+    Example:
+    ```mlir
+    %0 = amdgpu.permlane_swap %src 16 : f16
+    %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
+    ```
+
+    Operands:
+    * `$src`: Vector register to permute across lanes of the subgroup.
+    * `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32).
+    * `$fetch_inactive`: Optional. Used to determine behavior of a fetch from a disabled lane.
+      `fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
+      `fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`).
+    * `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
+      a disabled lane: use the value zero, or disable the write.
+      `bound_ctrl = false`: Do not write when source is from a disabled lane
+      `bound_ctrl = true`: Use zero as input if source is from a disabled lane
+
+    Note: Lowering is only supported on gfx950 and up.
+  }];
+  let arguments = (ins AnyIntegerOrFloatOr1DVector:$src,
+                   I32Attr:$row_length,
+                   DefaultValuedAttr<BoolAttr, "false">:$fetch_inactive,
+                   DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl);
+  let results = (outs AnyIntegerOrFloatOr1DVector:$result);
+  let assemblyFormat = [{
+    $src $row_length attr-dict `:` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
   let summary = "Barrier that includes a wait for LDS memory operations.";
   let description = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 64720bfe6cf50..b44d647cf7632 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -1876,6 +1877,65 @@ struct AMDGPUSwizzleBitModeLowering
   }
 };
 
+/// Lowers `amdgpu.permlane_swap` to `rocdl.permlane16.swap` or
+/// `rocdl.permlane32.swap`, chosen by the op's `row_length` attribute.
+///
+/// The converted source is split into i32 values with `LLVM::decomposeValue`,
+/// one swap is emitted per value, and element 0 of each returned pair is
+/// recombined into the source type with `LLVM::composeValue`.
+struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx950)
+      return op->emitOpError("permlane_swap is only supported on gfx950+");
+
+    Location loc = op.getLoc();
+    Type i32 = rewriter.getI32Type();
+    Value src = adaptor.getSrc();
+    unsigned rowLength = op.getRowLength();
+    bool fi = op.getFetchInactive();
+    bool boundctrl = op.getBoundCtrl();
+
+    SmallVector<Value> decomposed =
+        LLVM::decomposeValue(rewriter, loc, src, i32);
+
+    SmallVector<Value> permuted;
+    permuted.reserve(decomposed.size());
+    for (Value v : decomposed) {
+      // The ROCDL swap ops return a two-element struct of the operand type.
+      Type i32pair = LLVM::LLVMStructType::getLiteral(
+          rewriter.getContext(), {v.getType(), v.getType()});
+
+      // The same value is supplied for both swap operands. The op verifier
+      // only admits row lengths of 16 and 32, so no other case is reachable.
+      Value res;
+      if (rowLength == 16)
+        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+                                              boundctrl);
+      else if (rowLength == 32)
+        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
+                                              boundctrl);
+      else
+        llvm_unreachable("unsupported row length");
+
+      // Only element 0 of the pair is kept as the permuted value.
+      Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+      permuted.push_back(vdstNew);
+    }
+
+    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
   using Base::Base;
@@ -1944,6 +2004,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
            PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
            PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
-           TransposeLoadOpLowering>(converter, chipset);
+           TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index d7ffdcb58ddb5..11a40d663a201 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,20 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PermlaneSwapOp
+//===----------------------------------------------------------------------===//
+/// Accepts the op only when `row_length` is 16 or 32; any other value is
+/// rejected with an op error.
+LogicalResult PermlaneSwapOp::verify() {
+  unsigned rowLength = getRowLength();
+
+  if (rowLength != 16 && rowLength != 32)
+    return emitOpError("row_length attribute must either be 16 or 32.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GatherToLDSOp
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir new file mode 100644 index 0000000000000..aae2b1d0fd90c --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir @@ -0,0 +1,163 @@ +// RUN: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 --canonicalize %s | FileCheck %s + +// CHECK-LABEL: func @test_permlane16_i32 +// CHECK-SAME: (%[[ARG0:.*]]: i32) +func.func @test_permlane16_i32(%arg0 : i32) -> i32 { +// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: return %[[RES]] : i32 + %0 = amdgpu.permlane_swap %arg0 16 : i32 + return %0 : i32 +} + +// CHECK-LABEL: func @test_permlane16_i32_optional_attr +// CHECK-SAME: (%[[ARG0:.*]]: i32) +func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 { +// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: return %[[RES]] : i32 + %0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32 + return %0 : i32 +} + +// CHECK-LABEL: func @test_permlane32_i32 +// CHECK-SAME: (%[[ARG0:.*]]: i32) +func.func @test_permlane32_i32(%arg0 : i32) -> i32 { +// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: return %[[RES]] : i32 + %0 = amdgpu.permlane_swap %arg0 32 : i32 + return %0 : i32 +} + +// CHECK-LABEL: func @test_permlane16_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32) +func.func @test_permlane16_f32(%arg0 : f32) -> f32 { +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 +// CHECK: 
%[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 +// CHECK: return %[[RES_CAST]] : f32 + %0 = amdgpu.permlane_swap %arg0 16 : f32 + return %0 : f32 +} + +// CHECK-LABEL: func @test_permlane32_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32) +func.func @test_permlane32_f32(%arg0 : f32) -> f32 { +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 +// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 +// CHECK: return %[[RES_CAST]] : f32 + %0 = amdgpu.permlane_swap %arg0 32 : f32 + return %0 : f32 +} + +// CHECK-LABEL: func @test_permlane16_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16) +func.func @test_permlane16_f16(%arg0 : f16) -> f16 { +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 +// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 +// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 +// CHECK: return %[[RES_CAST]] : f16 + %0 = amdgpu.permlane_swap %arg0 16 : f16 + return %0 : f16 +} + +// CHECK-LABEL: func @test_permlane32_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16) +func.func @test_permlane32_f16(%arg0 : f16) -> f16 { +// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 +// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 +// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[RES:.*]] = 
llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 +// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 +// CHECK: return %[[RES_CAST]] : f16 + %0 = amdgpu.permlane_swap %arg0 32 : f16 + return %0 : f16 +} + +// CHECK-LABEL: func @test_permlane16_2xi32 +// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>) +func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { +// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: return %[[VEC_INSERT1]] : vector<2xi32> + %0 = amdgpu.permlane_swap %arg0 16 : vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @test_permlane32_2xi32 +// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>) +func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { +// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[ELEM0:.*]] = 
llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: return %[[VEC_INSERT1]] : vector<2xi32> + %0 = amdgpu.permlane_swap %arg0 32 : vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @test_permlane16_4xf16 +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>) +func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { +// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32> +// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: 
%[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16> +// CHECK: return %[[CAST2]] : vector<4xf16> + %0 = amdgpu.permlane_swap %arg0 16 : vector<4xf16> + return %0 : vector<4xf16> +} + +// CHECK-LABEL: func @test_permlane32_4xf16 +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>) +func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { +// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32> +// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32> +// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)> +// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)> +// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32> +// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32> +// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16> +// CHECK: return %[[CAST2]] : vector<4xf16> + %0 = amdgpu.permlane_swap %arg0 32 : vector<4xf16> + return %0 : vector<4xf16> +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 
87e11c028c62a..369e0fff538e1 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 { func.return %0 : f32 } +// CHECK-LABEL: func @permlane16_swap +func.func @permlane16_swap(%arg0 : f32) -> f32 { + // CHECK: amdgpu.permlane_swap + %0 = amdgpu.permlane_swap %arg0 16 : f32 + func.return %0 : f32 +} + +// CHECK-LABEL: func @permlane32_swap +func.func @permlane32_swap(%arg0 : f32) -> f32 { + // CHECK: amdgpu.permlane_swap + %0 = amdgpu.permlane_swap %arg0 32 : f32 + func.return %0 : f32 +} + // CHECK-LABEL: func @scaled_mfma func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> { // CHECK: amdgpu.scaled_mfma