Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,60 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
}
};

/// Rewrites an XOR-mode `gpu.shuffle` into `rocdl.ds_swizzle` ops.
///
/// The pattern only applies when the (unconverted) width is the constant 64
/// and the (unconverted) offset is a constant in the range [0, 32). Every
/// other shuffle is reported as a match failure so that another pattern can
/// handle it.
struct GPUShuffleOpLoweringSwizzle
    : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Exclusive upper bound for the xor offset (five bits), and the bit
    // position at which the offset is placed inside the swizzle mask.
    constexpr int64_t kOffsetLimit = int64_t{1} << 5;
    constexpr int64_t kOffsetShift = 10;

    if (op.getMode() != gpu::ShuffleMode::XOR)
      return rewriter.notifyMatchFailure(op, "only xor mode is supported");

    // Width and offset are inspected on the original op, i.e. before type
    // conversion, because both have to be compile-time constants.
    if (!isConstantIntValue(op.getWidth(), 64))
      return rewriter.notifyMatchFailure(op, "width must be 64");

    std::optional<int64_t> constOffset = getConstantIntValue(op.getOffset());
    if (!constOffset)
      return rewriter.notifyMatchFailure(op, "offset must be a constant");

    const int64_t offset = *constOffset;
    const bool offsetInRange = offset >= 0 && offset < kOffsetLimit;
    if (!offsetInRange)
      return rewriter.notifyMatchFailure(op, "unsupported offset value");

    Location loc = op.getLoc();
    Value source = adaptor.getValue();
    auto i32Type = rewriter.getI32Type();

    // TODO: It may be possible to lower specific xor patterns to DPP ops.

    // Swizzle mask: the low five bits are all set and the offset is stored
    // starting at bit 10. Bit 15 is 0 for the BitMode swizzle.
    // https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
    const int64_t swizzleMask =
        (kOffsetLimit - 1) | (offset << kOffsetShift);
    Value swizzleMaskValue =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, swizzleMask);

    // Split the value into 32-bit pieces, swizzle each piece on its own and
    // reassemble a value of the original type afterwards.
    SmallVector<Value> pieces =
        LLVM::decomposeValue(rewriter, loc, source, i32Type);
    SmallVector<Value> swizzledPieces;
    swizzledPieces.reserve(pieces.size());
    for (Value piece : pieces)
      swizzledPieces.push_back(rewriter.create<ROCDL::DsSwizzleOp>(
          loc, piece.getType(), piece, swizzleMaskValue));
    Value result =
        LLVM::composeValue(rewriter, loc, swizzledPieces, source.getType());

    // The width was verified to be 64 above, so the "valid" result of the
    // shuffle is always true.
    Value alwaysValid =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), 1);
    rewriter.replaceOp(op, {result, alwaysValid});
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

Expand All @@ -135,13 +189,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Location loc = op.getLoc();
Value initShflValue = adaptor.getValue();

const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);

auto int32Type = IntegerType::get(rewriter.getContext(), 32);
auto int32Type = rewriter.getI32Type();
Value width = adaptor.getWidth();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
Expand Down Expand Up @@ -177,14 +231,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {

SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
SmallVector<Value> swizzled;
SmallVector<Value> permuted;
for (Value v : decomposed) {
Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
dwordAlignedDstLane, v);
swizzled.emplace_back(res);
permuted.emplace_back(res);
}
Value shflValue =
LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
LLVM::composeValue(rewriter, loc, permuted, initShflValue.getType());
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
return success();
}
Expand Down Expand Up @@ -405,6 +459,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
// TODO: Add alignment for workgroup memory
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

// Try to lower to swizzle first
patterns.add<GPUShuffleOpLoweringSwizzle>(converter, /*benefit*/ 10);
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);

populateMathToROCDLConversionPatterns(converter, patterns);
Expand Down
45 changes: 26 additions & 19 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -659,52 +659,47 @@ gpu.module @test_module {
// -----

gpu.module @test_module {
// CHECK-LABEL: func @gpu_shuffle()
func.func @gpu_shuffle() -> (f32, f32, f32) {
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
%arg0 = arith.constant 1.0 : f32
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
%arg1 = arith.constant 4 : i32
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
%arg2 = arith.constant 23 : i32
// CHECK-LABEL: func @gpu_shuffle
// CHECK-SAME: (%[[VALUE:.*]]: f32, %[[OFFSET:.*]]: i32, %[[WIDTH:.*]]: i32)
func.func @gpu_shuffle(%arg0: f32, %arg1: i32, %arg2: i32) -> (f32, f32, f32) {
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %{{.*}} : i32
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#XOR]], %[[#WARP_OR_ZERO]] : i32
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#XOR]], %{{.*}} : i1, i32
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#OFFSET]], %[[#WARP_OR_ZERO]] : i32
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#OFFSET]], %{{.*}} : i1, i32
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[OFFSET]], %[[#WARP_OR_ZERO]] : i32
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[OFFSET]], %{{.*}} : i1, i32
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
%shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
// CHECK: %[[#DOWN:]] = llvm.add %[[#LANE_ID]], %{{.*}} : i32
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#DOWN]], %[[#WARP_OR_ZERO]] : i32
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#DOWN]], %{{.*}} : i1, i32
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
%shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
Expand All @@ -731,6 +726,18 @@ gpu.module @test_module {
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
func.return %shfl : vector<4xf16>
}

// Verifies that an xor shuffle with constant width 64 and constant offset 4
// is lowered to rocdl.ds_swizzle and that the shuffled value is what gets
// returned. The expected mask constant is 4127 (= 31 | (4 << 10)).
// CHECK-LABEL: func @gpu_shuffle_swizzle
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @gpu_shuffle_swizzle(%arg0: i32) -> i32 {
// CHECK: %[[MASK:.*]] = llvm.mlir.constant(4127 : i32) : i32
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG]], %[[MASK]] : (i32, i32) -> i32
// CHECK: llvm.return %[[RES]] : i32
%width = arith.constant 64 : i32
%offset = arith.constant 4 : i32
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
func.return %shfl : i32
}
}

// -----
Expand Down
Loading