Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,48 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}

// Lane-permutation op wrapping the ROCDL permlane swap operations. The
// row_length attribute is range-checked by the op's C++ verifier.
def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> {
  let summary = "AMDGPU permlane swap op";
  let description = [{
    High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations
    on rows of lanes in a subgroup.

    Supports arbitrary int/float/vector types, which will be repacked to i32 and
    one or more `rocdl.permlane{16,32}.swap` ops during lowering.
    Supported lane permutations:
    - Swap the data between odd and even rows of 16 lanes
    - Swap the data between the first 32 lanes and the last 32 lanes

    Example:
    ```mlir
    %0 = amdgpu.permlane_swap %src 16 : f16
    %1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
    ```

    Operands:
    * `$src`: Vector register to permute across lanes of the subgroup.
    * `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32).
    * `$fetch_inactive`: Optional. Used to determine behavior of a fetch from a disabled lane.
      `fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
      `fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`).
    * `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
      a disabled lane: use the value zero, or disable the write.
      `bound_ctrl = false`: Do not write when source is from a disabled lane
      `bound_ctrl = true`: Use zero as input if source is from a disabled lane

    Note: Lowering is only supported on gfx950 and up.
  }];
  let arguments = (ins AnyIntegerOrFloatOr1DVector:$src,
                       I32Attr:$row_length,
                       DefaultValuedAttr<BoolAttr, "false">:$fetch_inactive,
                       DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl);
  let results = (outs AnyIntegerOrFloatOr1DVector:$result);
  let assemblyFormat = [{
    $src $row_length attr-dict `:` type($result)
  }];
  let hasVerifier = 1;
}

def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{
Expand Down
51 changes: 50 additions & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -1876,6 +1877,54 @@ struct AMDGPUSwizzleBitModeLowering
}
};

/// Lowers `amdgpu.permlane_swap` to `rocdl.permlane16.swap` or
/// `rocdl.permlane32.swap`, depending on the op's `row_length`.
///
/// The (already type-converted) source is split into i32 pieces with
/// `LLVM::decomposeValue`, each piece is swapped independently, and the swapped
/// pieces are reassembled into the source type with `LLVM::composeValue`.
/// Matching fails with an op error when the target chipset is older than
/// gfx950.
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  /// Target chipset; gates availability of the swap instructions.
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundCtrl = op.getBoundCtrl();

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    permuted.reserve(decomposed.size());
    for (Value v : decomposed) {
      Value res;
      // The swap ops return both (updated) operands as a two-element struct.
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundCtrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundCtrl);
      else
        // The op verifier only admits row lengths of 16 and 32.
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      // The instruction exchanges one half of the rows of its first operand
      // with the other half of the rows of its second operand. Element 0
      // therefore carries the partner lane's data only in half of the rows; in
      // the remaining rows it is still this lane's own input and the
      // swapped-in data lives in element 1. Both operands are `v`, so an
      // element 0 that still equals `v` means the exchanged value (or one
      // identical to it) is in element 1. Select accordingly so that every
      // lane observes a true swap rather than a one-way copy.
      Value isUnchanged = LLVM::ICmpOp::create(
          rewriter, loc, LLVM::ICmpPredicate::eq, vdst0, v);
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isUnchanged, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
Expand Down Expand Up @@ -1944,6 +1993,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,18 @@ LogicalResult DPPOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//
/// Verifies that `row_length` is one of the two supported row widths (16 or
/// 32 lanes); any other value is rejected with an op error.
LogicalResult PermlaneSwapOp::verify() {
  switch (getRowLength()) {
  case 16:
  case 32:
    return success();
  default:
    return emitOpError("row_length attribute must either be 16 or 32.");
  }
}

//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
Expand Down
163 changes: 163 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/permlane.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// RUN: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 --canonicalize %s | FileCheck %s

// Scalar i32 with row length 16: no repacking is expected, just a single
// rocdl.permlane16.swap fed the argument as both operands (both flags default
// to false), with element 0 of the returned struct used as the result.
// NOTE(review): the expectations only use element 0 of the returned pair;
// confirm against the hardware swap semantics that this is right for all lanes.
// CHECK-LABEL: func @test_permlane16_i32
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 16 : i32
return %0 : i32
}

// Same as the plain i32 / row length 16 case, but with fetch_inactive and
// bound_ctrl both set to true; the expectations require both flags to be
// forwarded as `true, true` on the rocdl.permlane16.swap op.
// CHECK-LABEL: func @test_permlane16_i32_optional_attr
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
return %0 : i32
}

// Scalar i32 with row length 32: expects the rocdl.permlane32.swap variant to
// be selected, with the argument as both operands and default (false) flags.
// CHECK-LABEL: func @test_permlane32_i32
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 32 : i32
return %0 : i32
}

// Scalar f32 with row length 16: expects the value to be bitcast to i32 before
// the rocdl.permlane16.swap and the extracted result to be bitcast back to f32.
// CHECK-LABEL: func @test_permlane16_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32)
func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 16 : f32
return %0 : f32
}

// Scalar f32 with row length 32: same bitcast-to-i32 / bitcast-back shape as
// the row length 16 case, but expecting the rocdl.permlane32.swap variant.
// CHECK-LABEL: func @test_permlane32_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32)
func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 32 : f32
return %0 : f32
}

// Scalar f16 with row length 16 (a type narrower than 32 bits): expects the
// value to be bitcast to i16 and zero-extended to i32 before the swap, then
// truncated to i16 and bitcast back to f16 afterwards.
// CHECK-LABEL: func @test_permlane16_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16)
func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 16 : f16
return %0 : f16
}

// Scalar f16 with row length 32: same widen-to-i32 / narrow-back shape as the
// row length 16 case, but expecting the rocdl.permlane32.swap variant.
// CHECK-LABEL: func @test_permlane32_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16)
func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 32 : f16
return %0 : f16
}

// vector<2xi32> with row length 16 (a value wider than 32 bits): expects each
// i32 element to be extracted, swapped by its own rocdl.permlane16.swap, and
// the two results inserted back into a fresh (poison-initialised) vector.
// The constant and poison definitions are matched order-independently.
// CHECK-LABEL: func @test_permlane16_2xi32
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
%0 = amdgpu.permlane_swap %arg0 16 : vector<2xi32>
return %0 : vector<2xi32>
}

// vector<2xi32> with row length 32: same per-element extract / swap / insert
// shape as the row length 16 case, but expecting rocdl.permlane32.swap for
// each element.
// CHECK-LABEL: func @test_permlane32_2xi32
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
%0 = amdgpu.permlane_swap %arg0 32 : vector<2xi32>
return %0 : vector<2xi32>
}

// vector<4xf16> with row length 16 (sub-32-bit elements packed into a 64-bit
// vector): expects the vector to be bitcast to vector<2xi32>, each i32 element
// swapped by its own rocdl.permlane16.swap, the results reassembled into a
// vector<2xi32>, and that vector bitcast back to vector<4xf16>.
// CHECK-LABEL: func @test_permlane16_4xf16
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
// CHECK: return %[[CAST2]] : vector<4xf16>
%0 = amdgpu.permlane_swap %arg0 16 : vector<4xf16>
return %0 : vector<4xf16>
}

// vector<4xf16> with row length 32: same bitcast-to-vector<2xi32>, per-element
// swap, and bitcast-back shape as the row length 16 case, but expecting
// rocdl.permlane32.swap for each element.
// CHECK-LABEL: func @test_permlane32_4xf16
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
// CHECK: return %[[CAST2]] : vector<4xf16>
%0 = amdgpu.permlane_swap %arg0 32 : vector<4xf16>
return %0 : vector<4xf16>
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/AMDGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}

// Dialect-level smoke test for amdgpu.permlane_swap on f32 with row length 16:
// only checks that the op name appears in the output (presumably a parse/print
// round-trip; the RUN line of this file is not visible here).
// CHECK-LABEL: func @permlane16_swap
func.func @permlane16_swap(%arg0 : f32) -> f32 {
// CHECK: amdgpu.permlane_swap
%0 = amdgpu.permlane_swap %arg0 16 : f32
func.return %0 : f32
}

// Dialect-level smoke test for amdgpu.permlane_swap on f32 with row length 32:
// only checks that the op name appears in the output (presumably a parse/print
// round-trip; the RUN line of this file is not visible here).
// CHECK-LABEL: func @permlane32_swap
func.func @permlane32_swap(%arg0 : f32) -> f32 {
// CHECK: amdgpu.permlane_swap
%0 = amdgpu.permlane_swap %arg0 32 : f32
func.return %0 : f32
}

// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
Expand Down