Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,60 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}

// Lane permutations that a permlane operation can request, stored as a 32-bit
// integer enum. `genSpecializedAttr` is cleared here; a separate `EnumAttr`
// definition supplies the dialect attribute for this enum.
def AMDGPU_PermlanePerm
    : I32EnumAttr<"PermlanePerm",
                  "The possible permutations for a permlane operation",
                  [I32EnumAttrCase<"swap_16", 0>,
                   I32EnumAttrCase<"swap_32", 1>]> {
  let cppNamespace = "::mlir::amdgpu";
  let genSpecializedAttr = 0;
}

// Attribute of the AMDGPU dialect that carries an `AMDGPU_PermlanePerm` value;
// its mnemonic is `permlane_perm`.
def AMDGPU_PermlanePermAttr
    : EnumAttr<AMDGPU_Dialect, AMDGPU_PermlanePerm, "permlane_perm">;

// Permutes a value across rows of lanes in a subgroup. The op is `Pure` and
// its result type always equals the type of `src`.
def AMDGPU_PermlaneOp
    : AMDGPU_Op<"permlane", [Pure, AllTypesMatch<["result", "src"]>]> {
  let summary = "AMDGPU permlane op";
  let description = [{
    High-level wrapper on the `rocdl.permlane16.swap` and
    `rocdl.permlane32.swap` ops for permutations on rows of lanes in a
    subgroup.

    Supports arbitrary int/float/vector types, which will be repacked to i32 and
    one or more `rocdl.permlane16.swap` / `rocdl.permlane32.swap` ops during
    lowering.

    Supported lane permutations:
    - Swap the data between odd and even rows of 16 lanes (`swap_16`)
    - Swap the data between the first 32 lanes and the last 32 lanes (`swap_32`)

    Example:
    ```mlir
    %0 = amdgpu.permlane %src swap_16 : f16
    %1 = amdgpu.permlane %src swap_32 { fetch_inactive = true, bound_ctrl = true } : f16
    ```

    Operands:
    * `$src`: Vector register to permute across lanes of the subgroup.

    Attributes:
    * `$kind`: The kind of permutation operation.
    * `$fetch_inactive`: Optional, defaults to `false`. Used to determine the
      behavior of a fetch from a disabled lane.
      - `fetch_inactive = false`: If the source lane is disabled, use
        `bound_ctrl` to determine the source value.
      - `fetch_inactive = true`: If the source lane is disabled, fetch the
        source value anyway (ignoring `bound_ctrl`).
    * `$bound_ctrl`: Optional, defaults to `false`. Used to determine what a
      thread should do if its source operand is from a disabled lane: use the
      value zero, or disable the write.
      - `bound_ctrl = false`: Do not write when source is from a disabled lane.
      - `bound_ctrl = true`: Use zero as input if source is from a disabled
        lane.

    Note: Lowering is only supported on gfx950 and up.
  }];
  // `fetch_inactive` and `bound_ctrl` are printed through `attr-dict`, so they
  // only appear in the textual form when they differ from their defaults.
  let arguments = (ins AnyIntegerOrFloatOr1DVector:$src,
                       AMDGPU_PermlanePermAttr:$kind,
                       DefaultValuedAttr<BoolAttr, "false">:$fetch_inactive,
                       DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl);
  let results = (outs AnyIntegerOrFloatOr1DVector:$result);
  let assemblyFormat = [{
    $src $kind attr-dict `:` type($result)
  }];
}

def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
Expand Down
52 changes: 51 additions & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -1876,6 +1877,55 @@ struct AMDGPUSwizzleBitModeLowering
}
};

/// Lowers `amdgpu.permlane` to `rocdl.permlane16.swap` or
/// `rocdl.permlane32.swap`.
///
/// The converted source value is decomposed into i32 pieces, every piece is
/// permuted on its own, and the permuted pieces are composed back into a value
/// of the source's type. The pattern reports an error on chipsets older than
/// gfx950.
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneOp>(converter), chipset(chipset) {}
  /// Target chipset, used to reject the op where the swap is unavailable.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    auto kind = op.getKind();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    // The swap ops work on 32-bit values, so split the source into i32 pieces.
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});
      switch (kind) {
      case PermlanePerm::swap_16:
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
        break;
      case PermlanePerm::swap_32:
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
        break;
      }

      // The swap exchanges one half of the rows of its first operand with the
      // other half of the rows of its second operand and returns both updated
      // registers. With `v` passed for both operands, element 0 carries the
      // exchanged data only for the lanes whose row was rewritten in the first
      // operand; for the remaining lanes element 0 is still `v` and the
      // exchanged data is in element 1. Taking element 0 alone would leave
      // half of the lanes with their own value, so pick per lane.
      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
      Value isUnchanged = LLVM::ICmpOp::create(
          rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, v);
      // If element 0 equals `v` either this lane was not rewritten there, or
      // the exchanged value happens to equal `v`; element 1 is correct in
      // both situations.
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isUnchanged, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
Expand Down Expand Up @@ -1944,6 +1994,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
163 changes: 163 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// RUN: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 --canonicalize %s | FileCheck %s

// swap_16 on a scalar i32 with default attributes.
// CHECK-LABEL: func @test_permlane16_i32
// CHECK-SAME: (%[[SRC:.*]]: i32)
func.func @test_permlane16_i32(%src : i32) -> i32 {
  // CHECK: %[[SWAP:.*]] = rocdl.permlane16.swap %[[SRC]], %[[SRC]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: return %[[FIRST]] : i32
  %permuted = amdgpu.permlane %src swap_16 : i32
  return %permuted : i32
}

// swap_16 on a scalar i32 with both optional attributes set to true; the
// flags are expected to be forwarded to the swap op.
// CHECK-LABEL: func @test_permlane16_i32_optional_attr
// CHECK-SAME: (%[[SRC:.*]]: i32)
func.func @test_permlane16_i32_optional_attr(%src : i32) -> i32 {
  // CHECK: %[[SWAP:.*]] = rocdl.permlane16.swap %[[SRC]], %[[SRC]], true, true : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: return %[[FIRST]] : i32
  %permuted = amdgpu.permlane %src swap_16 { fetch_inactive = true, bound_ctrl = true } : i32
  return %permuted : i32
}

// swap_32 on a scalar i32 with default attributes.
// CHECK-LABEL: func @test_permlane32_i32
// CHECK-SAME: (%[[SRC:.*]]: i32)
func.func @test_permlane32_i32(%src : i32) -> i32 {
  // CHECK: %[[SWAP:.*]] = rocdl.permlane32.swap %[[SRC]], %[[SRC]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: return %[[FIRST]] : i32
  %permuted = amdgpu.permlane %src swap_32 : i32
  return %permuted : i32
}

// swap_16 on f32: the value is bitcast to i32 around the swap.
// CHECK-LABEL: func @test_permlane16_f32
// CHECK-SAME: (%[[SRC:.*]]: f32)
func.func @test_permlane16_f32(%src : f32) -> f32 {
  // CHECK: %[[BITS:.*]] = llvm.bitcast %[[SRC]] : f32 to i32
  // CHECK: %[[SWAP:.*]] = rocdl.permlane16.swap %[[BITS]], %[[BITS]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[AS_F32:.*]] = llvm.bitcast %[[FIRST]] : i32 to f32
  // CHECK: return %[[AS_F32]] : f32
  %permuted = amdgpu.permlane %src swap_16 : f32
  return %permuted : f32
}

// swap_32 on f32: the value is bitcast to i32 around the swap.
// CHECK-LABEL: func @test_permlane32_f32
// CHECK-SAME: (%[[SRC:.*]]: f32)
func.func @test_permlane32_f32(%src : f32) -> f32 {
  // CHECK: %[[BITS:.*]] = llvm.bitcast %[[SRC]] : f32 to i32
  // CHECK: %[[SWAP:.*]] = rocdl.permlane32.swap %[[BITS]], %[[BITS]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[AS_F32:.*]] = llvm.bitcast %[[FIRST]] : i32 to f32
  // CHECK: return %[[AS_F32]] : f32
  %permuted = amdgpu.permlane %src swap_32 : f32
  return %permuted : f32
}

// swap_16 on f16: the value is bitcast to i16 and zero-extended to i32 before
// the swap, then truncated and bitcast back afterwards.
// CHECK-LABEL: func @test_permlane16_f16
// CHECK-SAME: (%[[SRC:.*]]: f16)
func.func @test_permlane16_f16(%src : f16) -> f16 {
  // CHECK: %[[BITS:.*]] = llvm.bitcast %[[SRC]] : f16 to i16
  // CHECK: %[[WIDE:.*]] = llvm.zext %[[BITS]] : i16 to i32
  // CHECK: %[[SWAP:.*]] = rocdl.permlane16.swap %[[WIDE]], %[[WIDE]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[NARROW:.*]] = llvm.trunc %[[FIRST]] : i32 to i16
  // CHECK: %[[AS_F16:.*]] = llvm.bitcast %[[NARROW]] : i16 to f16
  // CHECK: return %[[AS_F16]] : f16
  %permuted = amdgpu.permlane %src swap_16 : f16
  return %permuted : f16
}

// swap_32 on f16: the value is bitcast to i16 and zero-extended to i32 before
// the swap, then truncated and bitcast back afterwards.
// CHECK-LABEL: func @test_permlane32_f16
// CHECK-SAME: (%[[SRC:.*]]: f16)
func.func @test_permlane32_f16(%src : f16) -> f16 {
  // CHECK: %[[BITS:.*]] = llvm.bitcast %[[SRC]] : f16 to i16
  // CHECK: %[[WIDE:.*]] = llvm.zext %[[BITS]] : i16 to i32
  // CHECK: %[[SWAP:.*]] = rocdl.permlane32.swap %[[WIDE]], %[[WIDE]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[FIRST:.*]] = llvm.extractvalue %[[SWAP]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[NARROW:.*]] = llvm.trunc %[[FIRST]] : i32 to i16
  // CHECK: %[[AS_F16:.*]] = llvm.bitcast %[[NARROW]] : i16 to f16
  // CHECK: return %[[AS_F16]] : f16
  %permuted = amdgpu.permlane %src swap_32 : f16
  return %permuted : f16
}

// swap_16 on vector<2xi32>: each element is extracted, swapped separately, and
// inserted into a fresh vector.
// CHECK-LABEL: func @test_permlane16_2xi32
// CHECK-SAME: (%[[SRC:.*]]: vector<2xi32>)
func.func @test_permlane16_2xi32(%src : vector<2xi32>) -> vector<2xi32> {
  // CHECK-DAG: %[[INIT:.*]] = llvm.mlir.poison : vector<2xi32>
  // CHECK-DAG: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK-DAG: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[SRC]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[SRC]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[PAIR0:.*]] = rocdl.permlane16.swap %[[E0]], %[[E0]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P0:.*]] = llvm.extractvalue %[[PAIR0]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[PAIR1:.*]] = rocdl.permlane16.swap %[[E1]], %[[E1]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P1:.*]] = llvm.extractvalue %[[PAIR1]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[P0]], %[[INIT]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[P1]], %[[INS0]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: return %[[INS1]] : vector<2xi32>
  %permuted = amdgpu.permlane %src swap_16 : vector<2xi32>
  return %permuted : vector<2xi32>
}

// swap_32 on vector<2xi32>: each element is extracted, swapped separately, and
// inserted into a fresh vector.
// CHECK-LABEL: func @test_permlane32_2xi32
// CHECK-SAME: (%[[SRC:.*]]: vector<2xi32>)
func.func @test_permlane32_2xi32(%src : vector<2xi32>) -> vector<2xi32> {
  // CHECK-DAG: %[[INIT:.*]] = llvm.mlir.poison : vector<2xi32>
  // CHECK-DAG: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK-DAG: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[SRC]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[SRC]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[PAIR0:.*]] = rocdl.permlane32.swap %[[E0]], %[[E0]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P0:.*]] = llvm.extractvalue %[[PAIR0]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[PAIR1:.*]] = rocdl.permlane32.swap %[[E1]], %[[E1]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P1:.*]] = llvm.extractvalue %[[PAIR1]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[P0]], %[[INIT]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[P1]], %[[INS0]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: return %[[INS1]] : vector<2xi32>
  %permuted = amdgpu.permlane %src swap_32 : vector<2xi32>
  return %permuted : vector<2xi32>
}

// swap_16 on vector<4xf16>: the vector is bitcast to vector<2xi32>, each i32 is
// swapped separately, and the rebuilt vector is bitcast back.
// CHECK-LABEL: func @test_permlane16_4xf16
// CHECK-SAME: (%[[SRC:.*]]: vector<4xf16>)
func.func @test_permlane16_4xf16(%src : vector<4xf16>) -> vector<4xf16> {
  // CHECK-DAG: %[[INIT:.*]] = llvm.mlir.poison : vector<2xi32>
  // CHECK-DAG: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK-DAG: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[PACKED:.*]] = llvm.bitcast %[[SRC]] : vector<4xf16> to vector<2xi32>
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[PACKED]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[PACKED]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[PAIR0:.*]] = rocdl.permlane16.swap %[[E0]], %[[E0]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P0:.*]] = llvm.extractvalue %[[PAIR0]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[PAIR1:.*]] = rocdl.permlane16.swap %[[E1]], %[[E1]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P1:.*]] = llvm.extractvalue %[[PAIR1]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[P0]], %[[INIT]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[P1]], %[[INS0]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[UNPACKED:.*]] = llvm.bitcast %[[INS1]] : vector<2xi32> to vector<4xf16>
  // CHECK: return %[[UNPACKED]] : vector<4xf16>
  %permuted = amdgpu.permlane %src swap_16 : vector<4xf16>
  return %permuted : vector<4xf16>
}

// swap_32 on vector<4xf16>: the vector is bitcast to vector<2xi32>, each i32 is
// swapped separately, and the rebuilt vector is bitcast back.
// CHECK-LABEL: func @test_permlane32_4xf16
// CHECK-SAME: (%[[SRC:.*]]: vector<4xf16>)
func.func @test_permlane32_4xf16(%src : vector<4xf16>) -> vector<4xf16> {
  // CHECK-DAG: %[[INIT:.*]] = llvm.mlir.poison : vector<2xi32>
  // CHECK-DAG: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK-DAG: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[PACKED:.*]] = llvm.bitcast %[[SRC]] : vector<4xf16> to vector<2xi32>
  // CHECK: %[[E0:.*]] = llvm.extractelement %[[PACKED]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[E1:.*]] = llvm.extractelement %[[PACKED]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[PAIR0:.*]] = rocdl.permlane32.swap %[[E0]], %[[E0]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P0:.*]] = llvm.extractvalue %[[PAIR0]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[PAIR1:.*]] = rocdl.permlane32.swap %[[E1]], %[[E1]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[P1:.*]] = llvm.extractvalue %[[PAIR1]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[INS0:.*]] = llvm.insertelement %[[P0]], %[[INIT]][%[[IDX0]] : i32] : vector<2xi32>
  // CHECK: %[[INS1:.*]] = llvm.insertelement %[[P1]], %[[INS0]][%[[IDX1]] : i32] : vector<2xi32>
  // CHECK: %[[UNPACKED:.*]] = llvm.bitcast %[[INS1]] : vector<2xi32> to vector<4xf16>
  // CHECK: return %[[UNPACKED]] : vector<4xf16>
  %permuted = amdgpu.permlane %src swap_32 : vector<4xf16>
  return %permuted : vector<4xf16>
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/AMDGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}

// Parse/print round trip of the swap_16 form.
// CHECK-LABEL: func @permlane16_swap
func.func @permlane16_swap(%src : f32) -> f32 {
  // CHECK: amdgpu.permlane
  %permuted = amdgpu.permlane %src swap_16 : f32
  func.return %permuted : f32
}

// Parse/print round trip of the swap_32 form.
// CHECK-LABEL: func @permlane32_swap
func.func @permlane32_swap(%src : f32) -> f32 {
  // CHECK: amdgpu.permlane
  %permuted = amdgpu.permlane %src swap_32 : f32
  func.return %permuted : f32
}

// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
Expand Down