-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][AMDGPU] Add PermlaneOp #154345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][AMDGPU] Add PermlaneOp #154345
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Tim Gymnich (tgymnich) ChangesFull diff: https://github.com/llvm/llvm-project/pull/154345.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2c646934c11c2..17386a2888788 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -656,6 +656,35 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}
+def AMDGPU_PermlanePerm : I32EnumAttr<"PermlanePerm",
+ "The possible permutations for a permlane operation",
+ [
+ I32EnumAttrCase<"swap_16", 0>,
+ I32EnumAttrCase<"swap_32", 1>,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_PermlanePermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_PermlanePerm,
+ "permlane_perm">;
+
+def AMDGPU_PermlaneOp : AMDGPU_Op<"permlane", [Pure, AllTypesMatch<["result", "src"]>]>,
+Arguments<(ins AnyIntegerOrFloatOr1DVector:$src,
+ AMDGPU_PermlanePermAttr:$kind)> {
+ let summary = "AMDGPU permlane op";
+ let description = [{
+ High-level wrapper on `rocdl.permlane` variants.
+
+ Supports arbitrary int/float/vector types, which will be repacked to i32 and
+ one or more `rocdl.permlane` ops during lowering.
+ }];
+ let results = (outs AnyIntegerOrFloatOr1DVector:$result);
+ let assemblyFormat = [{
+ $src $kind attr-dict `:` type($result)
+ }];
+}
+
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 64720bfe6cf50..ea437a3d5d994 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1876,6 +1877,49 @@ struct AMDGPUSwizzleBitModeLowering
}
};
+struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<PermlaneOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PermlaneOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ auto kind = op.getKind();
+ auto fi = rewriter.getBoolAttr(false);
+ auto boundctrl = rewriter.getBoolAttr(false);
+
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+
+ SmallVector<Value> permuted;
+ for (Value v : decomposed) {
+ Value res;
+ Type i32pair = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), { v.getType(), v.getType() });
+ switch (kind) {
+ case PermlanePerm::swap_16:
+ res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl);
+ break;
+ case PermlanePerm::swap_32:
+ res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl);
+ break;
+ }
+
+ Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ permuted.emplace_back(vdstNew);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1944,6 +1988,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering>(converter, chipset);
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5f30e5d..a490286523329 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
@@ -58,7 +59,44 @@ struct PromoteShuffleToSwizzlePattern
};
} // namespace
+namespace {
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue == 16 || offsetValue == 32)
+ return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
+
+ Location loc = op.getLoc();
+ Value res = rewriter.create<ROCDL::Permlane16SwapOp>(
+ loc, op.getResult(0).getType(), op.getValue(), op.getValue(),
+ /*fi=*/false, /*boundControl=*/false);
+ Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
+} // namespace
+
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
RewritePatternSet &patterns) {
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
new file mode 100644
index 0000000000000..545be794e3375
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_permlane16_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane %arg0 swap_16 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane32_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane %arg0 swap_32 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane16_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane %arg0 swap_16 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane32_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane %arg0 swap_32 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane16_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane %arg0 swap_16 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane32_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane %arg0 swap_32 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane16_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane %arg0 swap_16 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane32_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane %arg0 swap_32 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane16_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane %arg0 swap_16 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: func @test_permlane32_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane %arg0 swap_32 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 87e11c028c62a..5eb6d53d4d9fe 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}
+// CHECK-LABEL: func @permlane16_swap
+func.func @permlane16_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane
+ %0 = amdgpu.permlane %arg0 swap_16 : f32
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func @permlane32_swap
+func.func @permlane32_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane
+ %0 = amdgpu.permlane %arg0 swap_32 : f32
+ func.return %0 : f32
+}
+
// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
|
|
@llvm/pr-subscribers-backend-amdgpu Author: Tim Gymnich (tgymnich) ChangesFull diff: https://github.com/llvm/llvm-project/pull/154345.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2c646934c11c2..17386a2888788 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -656,6 +656,35 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}
+def AMDGPU_PermlanePerm : I32EnumAttr<"PermlanePerm",
+ "The possible permutations for a permlane operation",
+ [
+ I32EnumAttrCase<"swap_16", 0>,
+ I32EnumAttrCase<"swap_32", 1>,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_PermlanePermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_PermlanePerm,
+ "permlane_perm">;
+
+def AMDGPU_PermlaneOp : AMDGPU_Op<"permlane", [Pure, AllTypesMatch<["result", "src"]>]>,
+Arguments<(ins AnyIntegerOrFloatOr1DVector:$src,
+ AMDGPU_PermlanePermAttr:$kind)> {
+ let summary = "AMDGPU permlane op";
+ let description = [{
+ High-level wrapper on `rocdl.permlane` variants.
+
+ Supports arbitrary int/float/vector types, which will be repacked to i32 and
+ one or more `rocdl.permlane` ops during lowering.
+ }];
+ let results = (outs AnyIntegerOrFloatOr1DVector:$result);
+ let assemblyFormat = [{
+ $src $kind attr-dict `:` type($result)
+ }];
+}
+
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 64720bfe6cf50..ea437a3d5d994 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -1876,6 +1877,49 @@ struct AMDGPUSwizzleBitModeLowering
}
};
+struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<PermlaneOp>(converter), chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PermlaneOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Type i32 = rewriter.getI32Type();
+ Value src = adaptor.getSrc();
+ auto kind = op.getKind();
+ auto fi = rewriter.getBoolAttr(false);
+ auto boundctrl = rewriter.getBoolAttr(false);
+
+ SmallVector<Value> decomposed =
+ LLVM::decomposeValue(rewriter, loc, src, i32);
+
+ SmallVector<Value> permuted;
+ for (Value v : decomposed) {
+ Value res;
+ Type i32pair = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), { v.getType(), v.getType() });
+ switch (kind) {
+ case PermlanePerm::swap_16:
+ res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl);
+ break;
+ case PermlanePerm::swap_32:
+ res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, boundctrl);
+ break;
+ }
+
+ Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+ permuted.emplace_back(vdstNew);
+ }
+
+ Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+};
+
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@@ -1944,6 +1988,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
- TransposeLoadOpLowering>(converter, chipset);
+ TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
index 18c69f5f30e5d..a490286523329 100644
--- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
@@ -58,7 +59,44 @@ struct PromoteShuffleToSwizzlePattern
};
} // namespace
+namespace {
+struct PromoteShuffleToPermlanePattern
+ : public OpRewritePattern<gpu::ShuffleOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(gpu::ShuffleOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getMode() != gpu::ShuffleMode::XOR)
+ return rewriter.notifyMatchFailure(op,
+ "only xor shuffle mode is supported");
+
+ if (!isConstantIntValue(op.getWidth(), 64))
+ return rewriter.notifyMatchFailure(op,
+ "only 64 width shuffle is supported");
+
+ std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
+ if (!offset)
+ return rewriter.notifyMatchFailure(op,
+ "offset must be a constant integer");
+
+ int64_t offsetValue = *offset;
+ if (offsetValue == 16 || offsetValue == 32)
+ return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
+
+ Location loc = op.getLoc();
+ Value res = rewriter.create<ROCDL::Permlane16SwapOp>(
+ loc, op.getResult(0).getType(), op.getValue(), op.getValue(),
+ /*fi=*/false, /*boundControl=*/false);
+ Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
+ rewriter.replaceOp(op, {res, valid});
+ return success();
+ }
+};
+
+} // namespace
+
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
RewritePatternSet &patterns) {
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
+ patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext());
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
new file mode 100644
index 0000000000000..545be794e3375
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @test_permlane16_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane %arg0 swap_16 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane32_i32
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: return %[[RES]] : i32
+ %0 = amdgpu.permlane %arg0 swap_32 : i32
+ return %0 : i32
+}
+
+// CHECK-LABEL: func @test_permlane16_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane %arg0 swap_16 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane32_f32
+// CHECK-SAME: (%[[ARG0:.*]]: f32)
+func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
+// CHECK: return %[[RES_CAST]] : f32
+ %0 = amdgpu.permlane %arg0 swap_32 : f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func @test_permlane16_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane %arg0 swap_16 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane32_f16
+// CHECK-SAME: (%[[ARG0:.*]]: f16)
+func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
+// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
+// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
+// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
+// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
+// CHECK: return %[[RES_CAST]] : f16
+ %0 = amdgpu.permlane %arg0 swap_32 : f16
+ return %0 : f16
+}
+
+// CHECK-LABEL: func @test_permlane16_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane %arg0 swap_16 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane32_2xi32
+// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
+func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
+ %0 = amdgpu.permlane %arg0 swap_32 : vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func @test_permlane16_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane %arg0 swap_16 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
+
+// CHECK-LABEL: func @test_permlane32_4xf16
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
+func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
+// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
+// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
+// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
+// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
+// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
+// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
+// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
+// CHECK: return %[[CAST2]] : vector<4xf16>
+ %0 = amdgpu.permlane %arg0 swap_32 : vector<4xf16>
+ return %0 : vector<4xf16>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 87e11c028c62a..5eb6d53d4d9fe 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}
+// CHECK-LABEL: func @permlane16_swap
+func.func @permlane16_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane
+ %0 = amdgpu.permlane %arg0 swap_16 : f32
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func @permlane32_swap
+func.func @permlane32_swap(%arg0 : f32) -> f32 {
+ // CHECK: amdgpu.permlane
+ %0 = amdgpu.permlane %arg0 swap_32 : f32
+ func.return %0 : f32
+}
+
// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
|
e236d80 to
a061dba
Compare
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High-level comment - taking the name permlane for these gfx950 swap things feels wrong. This should either handle all the permlane[x][N] variants that aren't obvious implementations of gpu.shuffle or should have a different name?
|
@krzysz00 The current design allows for the addition of future permlane variants like: up, down, bcast, idx_gen, xor. For the rdna variants like |
|
The sources for those would be optional SSA variants, no? |
|
I think for the immediate case, this could be called |
@krzysz00 yes. Similar to |
@krzysz00 I'd be ok with that. But I still have a slight preference for the existing design. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
Co-authored-by: Jakub Kuderski <[email protected]>
235ccc7 to
362fdda
Compare
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the fixes. Let's wait for a second approval from @krzysz00 before merging
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Up to naming, I think this is fine
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
rocdl.permlane16.swapandrocdl.permlane32.swap