Skip to content

Commit 5e3eecb

Browse files
committed
[mlir] GPUOpsToROCDL: lower gpu.shuffle xor to rocdl.ds_swizzle when possible
Lower to `rocdl.ds_swizzle` if `width==64` and `0<=offset<(1 << 5)`. It may be possible to lower specific patterns to DPP, but i will leave it for later.
1 parent 5eabece commit 5e3eecb

File tree

2 files changed

+87
-24
lines changed

2 files changed

+87
-24
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,60 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
113113
}
114114
};
115115

116+
/// Lowers gpu.shuffle xor to ds_swizzle if possible.
117+
struct GPUShuffleOpLoweringSwizzle
118+
: public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
119+
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
120+
121+
LogicalResult
122+
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
123+
ConversionPatternRewriter &rewriter) const override {
124+
if (op.getMode() != gpu::ShuffleMode::XOR)
125+
return rewriter.notifyMatchFailure(op, "only xor mode is supported");
126+
127+
// Check unconverted width and offset.
128+
if (!isConstantIntValue(op.getWidth(), 64))
129+
return rewriter.notifyMatchFailure(op, "width must be 64");
130+
131+
std::optional<int64_t> offsetVal = getConstantIntValue(op.getOffset());
132+
if (!offsetVal)
133+
return rewriter.notifyMatchFailure(op, "offset must be a constant");
134+
135+
int64_t offset = *offsetVal;
136+
if (offset < 0 || offset >= (1 << 5))
137+
return rewriter.notifyMatchFailure(op, "unsupported offset value");
138+
139+
Location loc = op.getLoc();
140+
Value initShflValue = adaptor.getValue();
141+
142+
auto int32Type = rewriter.getI32Type();
143+
144+
// TODO: It may be possible to lower specific xor patterns to DPP ops.
145+
146+
// bit 15 is 0 for the BitMode swizzle.
147+
// https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/
148+
int64_t mask = ((1 << 5) - 1) | (offset << 10);
149+
Value maskValue = rewriter.create<LLVM::ConstantOp>(loc, int32Type, mask);
150+
151+
SmallVector<Value> decomposed =
152+
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
153+
SmallVector<Value> swizzled;
154+
for (Value v : decomposed) {
155+
Value res =
156+
rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
157+
swizzled.emplace_back(res);
158+
}
159+
Value shflValue =
160+
LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
161+
162+
// We checked width is 64, so it's always true.
163+
Value isActiveSrcLane =
164+
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), 1);
165+
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
166+
return success();
167+
}
168+
};
169+
116170
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
117171
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
118172

@@ -135,13 +189,13 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
135189
LogicalResult
136190
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
137191
ConversionPatternRewriter &rewriter) const override {
138-
Location loc = op->getLoc();
192+
Location loc = op.getLoc();
139193
Value initShflValue = adaptor.getValue();
140194

141195
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
142196
Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
143197

144-
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
198+
auto int32Type = rewriter.getI32Type();
145199
Value width = adaptor.getWidth();
146200
Value zero = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 0);
147201
Value negwidth = rewriter.create<LLVM::SubOp>(loc, int32Type, zero, width);
@@ -177,14 +231,14 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
177231

178232
SmallVector<Value> decomposed =
179233
LLVM::decomposeValue(rewriter, loc, initShflValue, int32Type);
180-
SmallVector<Value> swizzled;
234+
SmallVector<Value> permuted;
181235
for (Value v : decomposed) {
182236
Value res = rewriter.create<ROCDL::DsBpermuteOp>(loc, int32Type,
183237
dwordAlignedDstLane, v);
184-
swizzled.emplace_back(res);
238+
permuted.emplace_back(res);
185239
}
186240
Value shflValue =
187-
LLVM::composeValue(rewriter, loc, swizzled, initShflValue.getType());
241+
LLVM::composeValue(rewriter, loc, permuted, initShflValue.getType());
188242
rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
189243
return success();
190244
}
@@ -405,6 +459,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
405459
// TODO: Add alignment for workgroup memory
406460
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
407461

462+
// Try to lower to swizzle first
463+
patterns.add<GPUShuffleOpLoweringSwizzle>(converter, /*benefit*/ 10);
408464
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
409465

410466
populateMathToROCDLConversionPatterns(converter, patterns);

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -659,52 +659,47 @@ gpu.module @test_module {
659659
// -----
660660

661661
gpu.module @test_module {
662-
// CHECK-LABEL: func @gpu_shuffle()
663-
func.func @gpu_shuffle() -> (f32, f32, f32) {
664-
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
665-
%arg0 = arith.constant 1.0 : f32
666-
// CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : i32
667-
%arg1 = arith.constant 4 : i32
668-
// CHECK: %[[#WIDTH:]] = llvm.mlir.constant(23 : i32) : i32
669-
%arg2 = arith.constant 23 : i32
662+
// CHECK-LABEL: func @gpu_shuffle
663+
// CHECK-SAME: (%[[VALUE:.*]]: f32, %[[OFFSET:.*]]: i32, %[[WIDTH:.*]]: i32)
664+
func.func @gpu_shuffle(%arg0: f32, %arg1: i32, %arg2: i32) -> (f32, f32, f32) {
670665
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
671666
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
672-
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
673-
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
667+
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
668+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
674669
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
675670
// CHECK: %[[#XOR:]] = llvm.xor %[[#LANE_ID]], %{{.*}} : i32
676671
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#XOR]], %[[#WARP_OR_ZERO]] : i32
677672
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#XOR]], %{{.*}} : i1, i32
678673
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
679674
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
680-
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
675+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
681676
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
682677
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
683678
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : f32
684679
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
685680
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
686-
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
687-
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
681+
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
682+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
688683
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
689-
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#OFFSET]], %[[#WARP_OR_ZERO]] : i32
690-
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#OFFSET]], %{{.*}} : i1, i32
684+
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[OFFSET]], %[[#WARP_OR_ZERO]] : i32
685+
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[OFFSET]], %{{.*}} : i1, i32
691686
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
692687
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
693-
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
688+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
694689
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
695690
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
696691
%shfli, %predi = gpu.shuffle idx %arg0, %arg1, %arg2 : f32
697692
// CHECK: %[[#LANE_ID:]] = rocdl.mbcnt.hi
698693
// CHECK: %[[#ZERO:]] = llvm.mlir.constant(0 : i32) : i32
699-
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[#WIDTH]] : i32
700-
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[#WIDTH]] : i32
694+
// CHECK: %[[#NEG_WIDTH:]] = llvm.sub %[[#ZERO]], %[[WIDTH]] : i32
695+
// CHECK: %[[#ADD:]] = llvm.add %[[#LANE_ID]], %[[WIDTH]] : i32
701696
// CHECK: %[[#WARP_OR_ZERO:]] = llvm.and %[[#ADD]], %[[#NEG_WIDTH]] : i32
702697
// CHECK: %[[#DOWN:]] = llvm.add %[[#LANE_ID]], %{{.*}} : i32
703698
// CHECK: %[[#CMP:]] = llvm.icmp "slt" %[[#DOWN]], %[[#WARP_OR_ZERO]] : i32
704699
// CHECK: %[[#DST_LANE:]] = llvm.select %[[#CMP]], %[[#DOWN]], %{{.*}} : i1, i32
705700
// CHECK: %[[#TWO:]] = llvm.mlir.constant(2 : i32) : i32
706701
// CHECK: %[[#ALIGNED_DST_LANE:]] = llvm.shl %[[#DST_LANE]], %[[#TWO]] : i32
707-
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
702+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[VALUE]] : f32 to i32
708703
// CHECK: %[[#PERMUTE:]] = rocdl.ds_bpermute %[[#ALIGNED_DST_LANE]], %[[#CAST_VALUE]] : (i32, i32) -> i32
709704
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
710705
%shfld, %predd = gpu.shuffle down %arg0, %arg1, %arg2 : f32
@@ -731,6 +726,18 @@ gpu.module @test_module {
731726
%shfl, %pred = gpu.shuffle xor %arg0, %arg1, %arg2 : vector<4xf16>
732727
func.return %shfl : vector<4xf16>
733728
}
729+
730+
// CHECK-LABEL: func @gpu_shuffle_swizzle
731+
// CHECK-SAME: (%[[ARG:.*]]: i32)
732+
func.func @gpu_shuffle_swizzle(%arg0: i32) -> i32 {
733+
// CHECK: %[[MASK:.*]] = llvm.mlir.constant(4127 : i32) : i32
734+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG]], %[[MASK]] : (i32, i32) -> i32
735+
// CHECK: llvm.return %[[RES]] : i32
736+
%width = arith.constant 64 : i32
737+
%offset = arith.constant 4 : i32
738+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
739+
func.return %shfl : i32
740+
}
734741
}
735742

736743
// -----

0 commit comments

Comments
 (0)