Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 144 additions & 5 deletions mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

Expand Down Expand Up @@ -85,7 +86,7 @@ struct PromoteShuffleToPermlanePattern

int64_t offsetValue = *offset;
if (offsetValue != 16 && offsetValue != 32)
return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31");
return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");

Location loc = op.getLoc();
Value res = amdgpu::PermlaneSwapOp::create(
Expand All @@ -96,13 +97,151 @@ struct PromoteShuffleToPermlanePattern
}
};

/// Builds the IR that computes the current lane id as an `i32`.
///
/// Emits two `i32` constants (0 and -1), then `rocdl.mbcnt.lo(-1, 0)` followed
/// by `rocdl.mbcnt.hi(-1, <mbcnt.lo result>)`, and returns the result of the
/// latter. Both results are annotated with `noundef` and an LLVM range
/// attribute built from the bounds (0, 32) for the low part and (0, 64) for
/// the final value.
static Value getLaneId(RewriterBase &rewriter, Location loc) {
  auto *ctx = rewriter.getContext();
  auto i32Ty = IntegerType::get(ctx, 32);

  // Operands of the mbcnt chain: an all-ones mask and a zero starting count.
  Value startCount = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
  Value allOnesMask = arith::ConstantIntOp::create(rewriter, loc, -1, 32);

  NamedAttribute noUndefAttr = {LLVM::LLVMDialect::getNoUndefAttrName(),
                                rewriter.getUnitAttr()};

  // Range attribute with lower bound 0 and the given upper bound.
  auto rangeUpTo = [&](uint64_t upperBound) -> NamedAttribute {
    return {LLVM::LLVMDialect::getRangeAttrName(),
            LLVM::ConstantRangeAttr::get(ctx, APInt::getZero(32),
                                         APInt(32, upperBound))};
  };

  // Result attributes for a single-result op: `noundef` plus a range.
  auto resultAttrsFor = [&](NamedAttribute rangeAttr) {
    return rewriter.getArrayAttr(
        rewriter.getDictionaryAttr({noUndefAttr, rangeAttr}));
  };

  Value lowCount = ROCDL::MbcntLoOp::create(
      rewriter, loc, i32Ty, allOnesMask, startCount, /*arg_attrs=*/{},
      /*res_attrs=*/resultAttrsFor(rangeUpTo(32)));
  return ROCDL::MbcntHiOp::create(rewriter, loc, i32Ty, allOnesMask, lowCount,
                                  /*arg_attrs=*/{},
                                  /*res_attrs=*/resultAttrsFor(rangeUpTo(64)));
}

/// Try to promote `gpu.shuffle` to `amdgpu.dpp`.
///
/// The pattern applies only when all of the following hold:
///   * `width` is a constant integer in {4, 8, 12, 16, 32, 48, 64};
///   * `offset` is a constant integer;
///   * the mode/offset pair is one of
///       - `xor`  with offset 1 or 2 (lowered as `quad_perm`),
///       - `up`   with offset 1      (lowered as `wave_shr`),
///       - `down` with offset 1      (lowered as `wave_shl`).
/// `idx` shuffles are always rejected.
///
/// The first result is replaced by a select between the DPP result and poison,
/// keyed on the predicate `dstLane < ((laneId + width) & -width)`; the same
/// predicate replaces the op's second ("valid") result.
struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::ShuffleOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> width = getConstantIntValue(op.getWidth());
    if (!width)
      return rewriter.notifyMatchFailure(op,
                                         "width must be a constant integer");
    int64_t widthValue = *width;
    if (!llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue))
      return rewriter.notifyMatchFailure(
          op, "width must be 4, 8, 12, 16, 32, 48 or 64");

    std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
    if (!offset)
      return rewriter.notifyMatchFailure(op,
                                         "offset must be a constant integer");

    int64_t offsetValue = *offset;
    Location loc = op.getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);

    // Every non-failing case below overwrites `kind`; the initial value only
    // keeps the variable from ever being read uninitialized.
    amdgpu::DPPPerm kind = amdgpu::DPPPerm::quad_perm;
    // Only `quad_perm` carries a permutation argument; the other modes keep
    // the unit attribute.
    Attribute permAttr = rewriter.getUnitAttr();
    Value srcLane;
    Value dstLane;
    // Each case validates the offset *before* creating any IR so that a match
    // failure leaves the function untouched.
    switch (op.getMode()) {
    case gpu::ShuffleMode::XOR: {
      if (offsetValue != 1 && offsetValue != 2)
        return rewriter.notifyMatchFailure(
            op, "xor shuffle mode is only supported for offsets of 1 or 2");
      kind = amdgpu::DPPPerm::quad_perm;
      srcLane = getLaneId(rewriter, loc);
      dstLane = LLVM::XOrOp::create(rewriter, loc, int32Type, srcLane,
                                    op.getOffset());
      // Lane i of each quad reads from lane (i ^ offset).
      permAttr = offsetValue == 1 ? rewriter.getI32ArrayAttr({1, 0, 3, 2})
                                  : rewriter.getI32ArrayAttr({2, 3, 0, 1});
      break;
    }
    case gpu::ShuffleMode::UP: {
      if (offsetValue != 1)
        return rewriter.notifyMatchFailure(
            op, "up shuffle mode is only supported for offset 1");
      kind = amdgpu::DPPPerm::wave_shr;
      srcLane = getLaneId(rewriter, loc);
      dstLane = LLVM::SubOp::create(rewriter, loc, int32Type, srcLane,
                                    op.getOffset());
      break;
    }
    case gpu::ShuffleMode::DOWN: {
      if (offsetValue != 1)
        return rewriter.notifyMatchFailure(
            op, "down shuffle mode is only supported for offset 1");
      kind = amdgpu::DPPPerm::wave_shl;
      srcLane = getLaneId(rewriter, loc);
      dstLane = LLVM::AddOp::create(rewriter, loc, int32Type, srcLane,
                                    op.getOffset());
      break;
    }
    case gpu::ShuffleMode::IDX:
      return rewriter.notifyMatchFailure(op,
                                         "idx shuffle mode is not supported");
    }

    // Widths 4/8/12 narrow the bank mask, widths 16/32/48 narrow the row mask,
    // and width 64 leaves both masks fully set.
    unsigned bankMask = 0xF;
    unsigned rowMask = 0xF;
    switch (widthValue) {
    case 4:
      bankMask = 0x1;
      break;
    case 8:
      bankMask = 0x3;
      break;
    case 12:
      bankMask = 0x7;
      break;
    case 16:
      rowMask = 0x1;
      break;
    case 32:
      rowMask = 0x3;
      break;
    case 48:
      rowMask = 0x7;
      break;
    default:
      break;
    }

    constexpr bool boundCtrl = false;

    // isActiveSrcLane = dstLane < ((srcLane + width) & -width).
    Value negwidth =
        arith::ConstantIntOp::create(rewriter, loc, int32Type, -widthValue);
    Value add =
        arith::AddIOp::create(rewriter, loc, int32Type, srcLane, op.getWidth());
    Value widthOrZeroIfOutside =
        arith::AndIOp::create(rewriter, loc, int32Type, add, negwidth);
    Value isActiveSrcLane =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, dstLane,
                              widthOrZeroIfOutside);

    Type resultType = op.getResult(0).getType();
    // The shuffled value is passed as both the `old` and `src` DPP operands.
    Value dpp = amdgpu::DPPOp::create(rewriter, loc, resultType, op.getValue(),
                                      op.getValue(), kind, permAttr, rowMask,
                                      bankMask, boundCtrl);
    Value poison = LLVM::PoisonOp::create(rewriter, loc, resultType);

    // Lanes whose predicate is false receive poison instead of the DPP result.
    Value selectResult =
        arith::SelectOp::create(rewriter, loc, isActiveSrcLane, dpp, poison);

    rewriter.replaceOp(op, {selectResult, isActiveSrcLane});
    return success();
  }
};

} // namespace

/// Registers the patterns that promote `gpu.shuffle` to AMDGPU operations.
///
/// The swizzle pattern is added with benefit 1 and the DPP pattern with
/// benefit 2. The permlane pattern is added with benefit 3, and only when
/// `maybeChipset` holds a chipset that compares `>= kGfx950`.
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
    RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
  patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
                                               /*benefit*/ 1);
  patterns.add<PromoteShuffleToDPPPattern>(patterns.getContext(),
                                           /*benefit*/ 2);
  if (maybeChipset && *maybeChipset >= kGfx950)
    patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
                                                  /*benefit*/ 3);
}
86 changes: 84 additions & 2 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -735,13 +735,18 @@ gpu.module @test_module {
}

// CHECK-LABEL: func @gpu_shuffle_promote()
func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
func.func @gpu_shuffle_promote() -> (f32, f32, f32, f32, f32) {
// CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
// CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
%arg1 = arith.constant 4 : i32
%arg2 = arith.constant 16 : i32
%arg3 = arith.constant 32 : i32
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
%arg4 = arith.constant 64 : i32
// CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
%arg5 = arith.constant 1 : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
Expand All @@ -763,7 +768,84 @@ gpu.module @test_module {
// CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#SUB:]] = llvm.sub %[[#LANE_ID]], %[[#C1]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#SUB]], %[[#AND]] : i32
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 312, 15, 15, false : f32
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
%shflu, %predu = gpu.shuffle up %arg0, %arg5, %arg4 : f32
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#OP:]] = llvm.add %[[#LANE_ID]], %[[#C1]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
// CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
// CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#OP]], %[[#AND]] : i32
// CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 304, 15, 15, false : f32
// CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
%shfld, %predd = gpu.shuffle down %arg0, %arg5, %arg4 : f32
func.return %shfl1, %shfl2, %shfl3, %shflu, %shfld : f32, f32, f32, f32, f32
}

// Butterfly-style xor shuffles at width 64, one per power-of-two offset.
// Expected lowering per offset, as pinned by the directives below:
//   offsets 1 and 2   -> rocdl.update.dpp guarded by a lane-validity select,
//   offsets 4 and 8   -> rocdl.ds_swizzle,
//   offsets 16 and 32 -> rocdl.permlane16.swap / rocdl.permlane32.swap.
// CHECK-LABEL: func @gpu_butterfly_shuffle()
func.func @gpu_butterfly_shuffle() -> (f32, f32, f32, f32, f32, f32) {
  // CHECK: %[[#POISON:]] = llvm.mlir.poison : f32
  // CHECK: %[[#NEGWIDTH:]] = llvm.mlir.constant(-64 : i32) : i32
  // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
  %arg0 = arith.constant 1.0 : f32
  // CHECK: %[[#C1:]] = llvm.mlir.constant(1 : i32) : i32
  %c1 = arith.constant 1 : i32
  // CHECK: %[[#C2:]] = llvm.mlir.constant(2 : i32) : i32
  %c2 = arith.constant 2 : i32
  %c4 = arith.constant 4 : i32
  %c8 = arith.constant 8 : i32
  %c16 = arith.constant 16 : i32
  %c32 = arith.constant 32 : i32
  // CHECK: %[[#WIDTH:]] = llvm.mlir.constant(64 : i32) : i32
  %c64 = arith.constant 64 : i32
  // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
  // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C1]] : i32
  // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
  // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
  // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
  // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 177, 15, 15, false : f32
  // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
  %shfl1, %pred1 = gpu.shuffle xor %arg0, %c1, %c64 : f32
  // CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
  // CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %[[#C2]] : i32
  // CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
  // CHECK: %[[#AND:]] = llvm.and %[[#ADD]], %[[#NEGWIDTH]] : i32
  // CHECK: %[[#VALID:]] = llvm.icmp "slt" %[[#XOR]], %[[#AND]] : i32
  // CHECK: %[[#PERMUTE:]] = rocdl.update.dpp %[[#VALUE]], %[[#VALUE]] with 78, 15, 15, false : f32
  // CHECK: %[[#SELECT:]] = llvm.select %[[#VALID]], %[[#PERMUTE]], %[[#POISON]] : i1, f32
  %shfl2, %pred2 = gpu.shuffle xor %arg0, %c2, %c64 : f32
  // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
  // CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
  // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
  // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
  %shfl3, %pred3 = gpu.shuffle xor %arg0, %c4, %c64 : f32
  // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
  // CHECK: %[[#MASK:]] = llvm.mlir.constant(8223 : i32) : i32
  // CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
  // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
  %shfl4, %pred4 = gpu.shuffle xor %arg0, %c8, %c64 : f32
  // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
  // CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE]][1] : !llvm.struct<(i32, i32)>
  // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
  // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
  // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
  %shfl5, %pred5 = gpu.shuffle xor %arg0, %c16, %c64 : f32
  // CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
  // CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
  // CHECK: %[[#EXTRACT0:]] = llvm.extractvalue %[[#PERMUTE]][0] : !llvm.struct<(i32, i32)>
  // CHECK: %[[#EXTRACT1:]] = llvm.extractvalue %[[#PERMUTE]][1] : !llvm.struct<(i32, i32)>
  // CHECK: %[[#CMP:]] = llvm.icmp "eq" %[[#EXTRACT0]], %[[#CAST_VALUE]] : i32
  // CHECK: %[[#SEL:]] = llvm.select %[[#CMP]], %[[#EXTRACT1]], %[[#EXTRACT0]] : i1, i32
  // CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#SEL]] : i32 to f32
  %shfl6, %pred6 = gpu.shuffle xor %arg0, %c32, %c64 : f32
  func.return %shfl1, %shfl2, %shfl3, %shfl4, %shfl5, %shfl6 : f32, f32, f32, f32, f32, f32
}

// CHECK-LABEL: func @gpu_shuffle_vec
Expand Down