Skip to content

Commit 003cbbd

Browse files
authored
[mlir][amdgpu] Promote gpu.shuffle to amdgpu.permlane_swap (llvm#154933)
- promote `gpu.shuffle %src xor {16,32} 64` to `amdgpu.permlane_swap %src {16,32}`
1 parent ec860d1 commit 003cbbd

File tree

9 files changed

+114
-17
lines changed

9 files changed

+114
-17
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,8 @@ def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["re
670670

671671
Example:
672672
```mlir
673-
%0 = amdgpu.permlane %src 16 : f16
674-
%1 = amdgpu.permlane %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
673+
%0 = amdgpu.permlane_swap %src 16 : f16
674+
%1 = amdgpu.permlane_swap %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
675675
```
676676

677677
Operands:

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,10 @@ def ApplyGPUPromoteShuffleToAMDGPUPatternsOp : Op<Transform_Dialect,
331331
Collects patterns that are trying to promote `gpu.shuffle`s to specialized
332332
AMDGPU intrinsics.
333333
}];
334-
let assemblyFormat = "attr-dict";
334+
let arguments = (ins OptionalAttr<StrAttr>:$chipset);
335+
let assemblyFormat = [{
336+
(`chipset` `=` $chipset^)? attr-dict
337+
}];
335338
}
336339

337340

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
114114
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns);
115115

116116
/// Tries to promote `gpu.shuffle`s to specialized AMDGPU intrinsics.
117-
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns);
117+
void populateGpuPromoteShuffleToAMDGPUPatterns(
118+
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset);
118119

119120
/// Generate the code for registering passes.
120121
#define GEN_PASS_REGISTRATION

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,7 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
18931893
Location loc = op.getLoc();
18941894
Type i32 = rewriter.getI32Type();
18951895
Value src = adaptor.getSrc();
1896-
unsigned row_length = op.getRowLength();
1896+
unsigned rowLength = op.getRowLength();
18971897
bool fi = op.getFetchInactive();
18981898
bool boundctrl = op.getBoundCtrl();
18991899

@@ -1906,10 +1906,10 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
19061906
Type i32pair = LLVM::LLVMStructType::getLiteral(
19071907
rewriter.getContext(), {v.getType(), v.getType()});
19081908

1909-
if (row_length == 16)
1909+
if (rowLength == 16)
19101910
res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
19111911
boundctrl);
1912-
else if (row_length == 32)
1912+
else if (rowLength == 32)
19131913
res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
19141914
boundctrl);
19151915
else

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ struct LowerGpuOpsToROCDLOpsPass final
327327
{
328328
RewritePatternSet patterns(ctx);
329329
populateGpuRewritePatterns(patterns);
330-
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
330+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
331331
(void)applyPatternsGreedily(m, std::move(patterns));
332332
}
333333

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
1414
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1515
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
16+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1617
#include "mlir/Dialect/Arith/IR/Arith.h"
1718
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1819
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
@@ -43,6 +44,7 @@
4344
#include "llvm/Support/ErrorHandling.h"
4445
#include "llvm/Support/InterleavedRange.h"
4546
#include "llvm/Support/LogicalResult.h"
47+
#include <optional>
4648
#include <type_traits>
4749

4850
using namespace mlir;
@@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) {
170172

171173
void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
172174
RewritePatternSet &patterns) {
173-
populateGpuPromoteShuffleToAMDGPUPatterns(patterns);
175+
std::optional<StringRef> chipsetName = getChipset();
176+
std::optional<amdgpu::Chipset> maybeChipset;
177+
if (chipsetName) {
178+
FailureOr<amdgpu::Chipset> parsedChipset =
179+
amdgpu::Chipset::parse(*chipsetName);
180+
assert(llvm::succeeded(parsedChipset) && "expected valid chipset");
181+
maybeChipset = parsedChipset;
182+
}
183+
184+
populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset);
174185
}
175186

176187
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,21 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1415
#include "mlir/Dialect/GPU/Transforms/Passes.h"
1516

1617
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1718
#include "mlir/Dialect/Arith/IR/Arith.h"
1819
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/IR/PatternMatch.h"
21+
#include <optional>
2022

2123
using namespace mlir;
2224

2325
namespace {
26+
27+
constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0);
28+
2429
/// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64
2530
/// and offset must be a constant integer in the range [0, 31].
2631
struct PromoteShuffleToSwizzlePattern
@@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern
5661
return success();
5762
}
5863
};
64+
65+
/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64
66+
/// and offset must be a constant integer in the set {16, 32}.
67+
struct PromoteShuffleToPermlanePattern
68+
: public OpRewritePattern<gpu::ShuffleOp> {
69+
using OpRewritePattern::OpRewritePattern;
70+
71+
LogicalResult matchAndRewrite(gpu::ShuffleOp op,
72+
PatternRewriter &rewriter) const override {
73+
if (op.getMode() != gpu::ShuffleMode::XOR)
74+
return rewriter.notifyMatchFailure(op,
75+
"only xor shuffle mode is supported");
76+
77+
if (!isConstantIntValue(op.getWidth(), 64))
78+
return rewriter.notifyMatchFailure(op,
79+
"only 64 width shuffle is supported");
80+
81+
std::optional<int64_t> offset = getConstantIntValue(op.getOffset());
82+
if (!offset)
83+
return rewriter.notifyMatchFailure(op,
84+
"offset must be a constant integer");
85+
86+
int64_t offsetValue = *offset;
87+
if (offsetValue != 16 && offsetValue != 32)
88+
return rewriter.notifyMatchFailure(op, "offset must be either 16 or 32");
89+
90+
Location loc = op.getLoc();
91+
Value res = amdgpu::PermlaneSwapOp::create(
92+
rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue);
93+
Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1);
94+
rewriter.replaceOp(op, {res, valid});
95+
return success();
96+
}
97+
};
98+
5999
} // namespace
60100

61101
void mlir::populateGpuPromoteShuffleToAMDGPUPatterns(
62-
RewritePatternSet &patterns) {
63-
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext());
102+
RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) {
103+
patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(),
104+
/*benefit*/ 1);
105+
if (maybeChipset && *maybeChipset >= kGfx950)
106+
patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(),
107+
/*benefit*/ 2);
64108
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
2-
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
3-
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
1+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950' -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
3+
// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx950 index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
44

55
// CHECK-LABEL: @test_module
66
// CHECK-SAME: llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
@@ -734,14 +734,40 @@ gpu.module @test_module {
734734
func.return %shfl, %shfli, %shflu, %shfld : f32, f32, f32, f32
735735
}
736736

737+
// CHECK-LABEL: func @gpu_shuffle_promote()
738+
func.func @gpu_shuffle_promote() -> (f32, f32, f32) {
739+
// CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
740+
%arg0 = arith.constant 1.0 : f32
741+
%arg1 = arith.constant 4 : i32
742+
%arg2 = arith.constant 16 : i32
743+
%arg3 = arith.constant 32 : i32
744+
%arg4 = arith.constant 64 : i32
745+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
746+
// CHECK: %[[#MASK:]] = llvm.mlir.constant(4127 : i32) : i32
747+
// CHECK: %[[#PERMUTE:]] = rocdl.ds_swizzle %[[#CAST_VALUE]], %[[#MASK]] : (i32, i32) -> i32
748+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#PERMUTE]] : i32 to f32
749+
%shfl1, %pred1 = gpu.shuffle xor %arg0, %arg1, %arg4 : f32
750+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
751+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane16.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
752+
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
753+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
754+
%shfl2, %pred2 = gpu.shuffle xor %arg0, %arg2, %arg4 : f32
755+
// CHECK: %[[#CAST_VALUE:]] = llvm.bitcast %[[#VALUE]] : f32 to i32
756+
// CHECK: %[[#PERMUTE:]] = rocdl.permlane32.swap %[[#CAST_VALUE]], %[[#CAST_VALUE]], false, false : (i32, i32) -> <(i32, i32)>
757+
// CHECK: %[[#EXTRACT:]] = llvm.extractvalue %[[#PERMUTE:]][0] : !llvm.struct<(i32, i32)>
758+
// CHECK: %[[#CAST_SHFL_VALUE:]] = llvm.bitcast %[[#EXTRACT]] : i32 to f32
759+
%shfl3, %pred3 = gpu.shuffle xor %arg0, %arg3, %arg4 : f32
760+
func.return %shfl1, %shfl2, %shfl3 : f32, f32, f32
761+
}
762+
737763
// CHECK-LABEL: func @gpu_shuffle_vec
738764
// CHECK-SAME: (%[[ARG:.*]]: vector<4xf16>, %{{.*}}: i32, %{{.*}}: i32)
739765
func.func @gpu_shuffle_vec(%arg0: vector<4xf16>, %arg1: i32, %arg2: i32) -> vector<4xf16> {
740766
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG]] : vector<4xf16> to vector<2xi32>
741767
// CHECK: %[[IDX0:.*]] = llvm.mlir.constant(0 : i32) : i32
742-
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %13[%[[IDX0]] : i32] : vector<2xi32>
768+
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[IDX0]] : i32] : vector<2xi32>
743769
// CHECK: %[[IDX1:.*]] = llvm.mlir.constant(1 : i32) : i32
744-
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %13[%[[IDX1]] : i32] : vector<2xi32>
770+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[IDX1]] : i32] : vector<2xi32>
745771
// CHECK: %[[PERM0:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM0]] : (i32, i32) -> i32
746772
// CHECK: %[[PERM1:.*]] = rocdl.ds_bpermute %{{.*}}, %[[ELEM1]] : (i32, i32) -> i32
747773
// CHECK: %[[V0:.*]] = llvm.mlir.poison : vector<2xi32>

mlir/test/Dialect/GPU/promote-shuffle-amdgpu.mlir

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module attributes {transform.with_named_sequence} {
44
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
55
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
66
transform.apply_patterns to %func {
7-
transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu
7+
transform.apply_patterns.gpu.gpu_shuffle_to_amdgpu chipset = "gfx950"
88
} : !transform.any_op
99
transform.yield
1010
}
@@ -21,3 +21,15 @@ func.func @gpu_shuffle_swizzle(%arg0: i32) -> (i32, i1) {
2121
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
2222
func.return %shfl, %pred : i32, i1
2323
}
24+
25+
// CHECK-LABEL: func @gpu_shuffle_permlane_swap
26+
// CHECK-SAME: (%[[ARG:.*]]: i32)
27+
func.func @gpu_shuffle_permlane_swap(%arg0: i32) -> (i32, i1) {
28+
// CHECK: %[[TRUE:.*]] = arith.constant true
29+
// CHECK: %[[RES:.*]] = amdgpu.permlane_swap %[[ARG]] 32 : i32
30+
// CHECK: return %[[RES]], %[[TRUE]] : i32, i1
31+
%width = arith.constant 64 : i32
32+
%offset = arith.constant 32 : i32
33+
%shfl, %pred = gpu.shuffle xor %arg0, %offset, %width : i32
34+
func.return %shfl, %pred : i32, i1
35+
}

0 commit comments

Comments
 (0)