Skip to content

Commit 0916901

Browse files
committed
[mlir] AMDGPUToROCDL: lower amdgpu.swizzle_bitmode
Repack `amdgpu.swizzle_bitmode` arguments and lower it to `rocdl.ds_swizzle`. Repacking logic is follows: * `sizeof(arg) < sizeof(i32)`: bitcast to integer and zext to i32 and then trunc and bitcast back. * `sizeof(arg) == sizeof(i32)`: just bitcast to i32 and back if not i32 * `sizeof(arg) > sizeof(i32)`: bitcast to `vector<Nxi32>`, extract individual elements and do a series of `rocdl.ds_swizzle` and then compose vector and bitcast back. Added repacking logic to LLVM utils so it can be used elsewhere. I'm planning to use it for `gpu.shuffle` later.
1 parent 0abf227 commit 0916901

File tree

4 files changed

+212
-0
lines changed

4 files changed

+212
-0
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ LogicalResult oneToOneRewrite(
3131
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
3232

3333
} // namespace detail
34+
35+
/// Decomposes a `src` value into a set of values of type `dstType` through
36+
/// series of bitcasts and vector ops. Src and dst types are expected to be int
37+
/// or float types or vector types of them.
38+
SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
39+
Type dstType);
40+
41+
/// Composes a set of `src` values into a single value of type `dstType` through
42+
/// series of bitcasts and vector ops. Inversely to `decomposeValue`, this
43+
/// function is used to combine multiple values into a single value.
44+
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
45+
Type dstType);
3446
} // namespace LLVM
3547

3648
/// Base class for operation conversions targeting the LLVM IR dialect. It

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
13771377
}
13781378
};
13791379

1380+
struct AMDGPUSwizzleBitModeLowering
1381+
: public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1382+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1383+
1384+
LogicalResult
1385+
matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1386+
ConversionPatternRewriter &rewriter) const override {
1387+
Location loc = op.getLoc();
1388+
Type i32 = rewriter.getI32Type();
1389+
Value src = adaptor.getSrc();
1390+
SmallVector<Value> decomposed =
1391+
LLVM::decomposeValue(rewriter, loc, src, i32);
1392+
unsigned andMask = op.getAndMask();
1393+
unsigned orMask = op.getOrMask();
1394+
unsigned xorMask = op.getXorMask();
1395+
1396+
// bit 15 is 0 for the BitMode swizzle.
1397+
unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1398+
Value maskValue = createI32Constant(rewriter, loc, mask);
1399+
SmallVector<Value> swizzled;
1400+
for (Value v : decomposed) {
1401+
Value res =
1402+
rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1403+
swizzled.emplace_back(res);
1404+
}
1405+
1406+
Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
1407+
rewriter.replaceOp(op, result);
1408+
return success();
1409+
}
1410+
};
1411+
13801412
struct ConvertAMDGPUToROCDLPass
13811413
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
13821414
using Base::Base;
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
14441476
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
14451477
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
14461478
GatherToLDSOpLowering>(converter, chipset);
1479+
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
14471480
}

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite(
381381
rewriter.replaceOp(op, results);
382382
return success();
383383
}
384+
385+
static unsigned getBitWidth(Type type) {
386+
if (type.isIntOrFloat())
387+
return type.getIntOrFloatBitWidth();
388+
389+
auto vec = cast<VectorType>(type);
390+
return vec.getNumElements() * getBitWidth(vec.getElementType());
391+
}
392+
393+
static Value createI32Constant(OpBuilder &builder, Location loc,
394+
int32_t value) {
395+
Type i32 = builder.getI32Type();
396+
return builder.create<LLVM::ConstantOp>(loc, i32, value);
397+
}
398+
399+
SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc,
400+
Value src, Type dstType) {
401+
Type srcType = src.getType();
402+
if (srcType == dstType)
403+
return {src};
404+
405+
unsigned srcBitWidth = getBitWidth(srcType);
406+
unsigned dstBitWidth = getBitWidth(dstType);
407+
if (srcBitWidth == dstBitWidth) {
408+
Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src);
409+
return {cast};
410+
}
411+
412+
if (dstBitWidth > srcBitWidth) {
413+
auto smallerInt = builder.getIntegerType(srcBitWidth);
414+
if (srcType != smallerInt)
415+
src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src);
416+
417+
auto largerInt = builder.getIntegerType(dstBitWidth);
418+
Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src);
419+
return {res};
420+
}
421+
assert(srcBitWidth % dstBitWidth == 0 &&
422+
"src bit width must be a multiple of dst bit width");
423+
int64_t numElements = srcBitWidth / dstBitWidth;
424+
auto vecType = VectorType::get(numElements, dstType);
425+
426+
src = builder.create<LLVM::BitcastOp>(loc, vecType, src);
427+
428+
SmallVector<Value> res;
429+
for (auto i : llvm::seq<int64_t>(0, numElements)) {
430+
Value idx = createI32Constant(builder, loc, i);
431+
Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx);
432+
res.emplace_back(elem);
433+
}
434+
435+
return res;
436+
}
437+
438+
Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
439+
Type dstType) {
440+
assert(!src.empty() && "src range must not be empty");
441+
if (src.size() == 1) {
442+
Value res = src.front();
443+
if (res.getType() == dstType)
444+
return res;
445+
446+
unsigned srcBitWidth = getBitWidth(res.getType());
447+
unsigned dstBitWidth = getBitWidth(dstType);
448+
if (dstBitWidth < srcBitWidth) {
449+
auto largerInt = builder.getIntegerType(srcBitWidth);
450+
if (res.getType() != largerInt)
451+
res = builder.create<LLVM::BitcastOp>(loc, largerInt, res);
452+
453+
auto smallerInt = builder.getIntegerType(dstBitWidth);
454+
res = builder.create<LLVM::TruncOp>(loc, smallerInt, res);
455+
}
456+
457+
if (res.getType() != dstType)
458+
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
459+
460+
return res;
461+
}
462+
463+
int64_t numElements = src.size();
464+
auto srcType = VectorType::get(numElements, src.front().getType());
465+
Value res = builder.create<LLVM::PoisonOp>(loc, srcType);
466+
for (auto &&[i, elem] : llvm::enumerate(src)) {
467+
Value idx = createI32Constant(builder, loc, i);
468+
res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx);
469+
}
470+
471+
if (res.getType() != dstType)
472+
res = builder.create<LLVM::BitcastOp>(loc, dstType, res);
473+
474+
return res;
475+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @test_swizzle_i32
4+
// CHECK-SAME: (%[[ARG0:.*]]: i32)
5+
func.func @test_swizzle_i32(%arg0 : i32) -> i32 {
6+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
7+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32
8+
// CHECK: return %[[RES]] : i32
9+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32
10+
return %0 : i32
11+
}
12+
13+
// CHECK-LABEL: func @test_swizzle_f32
14+
// CHECK-SAME: (%[[ARG0:.*]]: f32)
15+
func.func @test_swizzle_f32(%arg0 : f32) -> f32 {
16+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
17+
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
18+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32
19+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
20+
// CHECK: return %[[RES_CAST]] : f32
21+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
22+
return %0 : f32
23+
}
24+
25+
// CHECK-LABEL: func @test_swizzle_f16
26+
// CHECK-SAME: (%[[ARG0:.*]]: f16)
27+
func.func @test_swizzle_f16(%arg0 : f16) -> f16 {
28+
// CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
29+
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
30+
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
31+
// CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32
32+
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
33+
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
34+
// CHECK: return %[[RES_CAST]] : f16
35+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16
36+
return %0 : f16
37+
}
38+
39+
// CHECK-LABEL: func @test_swizzle_2xi32
40+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
41+
func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
42+
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
43+
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
44+
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
45+
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
46+
// CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
47+
// CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
48+
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
49+
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
50+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
51+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
52+
// CHECK: return %[[V3]] : vector<2xi32>
53+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32>
54+
return %0 : vector<2xi32>
55+
}
56+
57+
// CHECK-LABEL: func @test_swizzle_4xf16
58+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
59+
func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
60+
// CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32>
61+
// CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32
62+
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
63+
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
64+
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
65+
// CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
66+
// CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
67+
// CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32
68+
// CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32
69+
// CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32>
70+
// CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32>
71+
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16>
72+
// CHECK: return %[[CAST2]] : vector<4xf16>
73+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16>
74+
return %0 : vector<4xf16>
75+
}

0 commit comments

Comments
 (0)