-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir] AMDGPUToROCDL: lower amdgpu.swizzle_bitmode
#136223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1377,6 +1377,38 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> { | |
| } | ||
| }; | ||
|
|
||
| struct AMDGPUSwizzleBitModeLowering | ||
| : public ConvertOpToLLVMPattern<SwizzleBitModeOp> { | ||
| using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| Location loc = op.getLoc(); | ||
| Type i32 = rewriter.getI32Type(); | ||
| Value src = adaptor.getSrc(); | ||
| SmallVector<Value> decomposed = | ||
| LLVM::decomposeValue(rewriter, loc, src, i32); | ||
| unsigned andMask = op.getAndMask(); | ||
| unsigned orMask = op.getOrMask(); | ||
| unsigned xorMask = op.getXorMask(); | ||
|
|
||
| // bit 15 is 0 for the BitMode swizzle. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we link to the ISA manual or llvm intrinsics here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| unsigned mask = andMask | (orMask << 5) | (xorMask << 10); | ||
| Value maskValue = createI32Constant(rewriter, loc, mask); | ||
| SmallVector<Value> swizzled; | ||
| for (Value v : decomposed) { | ||
| Value res = | ||
| rewriter.create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue); | ||
| swizzled.emplace_back(res); | ||
| } | ||
|
|
||
| Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType()); | ||
| rewriter.replaceOp(op, result); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct ConvertAMDGPUToROCDLPass | ||
| : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { | ||
| using Base::Base; | ||
|
|
@@ -1444,4 +1476,5 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, | |
| MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, | ||
| PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, | ||
| GatherToLDSOpLowering>(converter, chipset); | ||
| patterns.add<AMDGPUSwizzleBitModeLowering>(converter); | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -381,3 +381,95 @@ LogicalResult LLVM::detail::oneToOneRewrite( | |
| rewriter.replaceOp(op, results); | ||
| return success(); | ||
| } | ||
|
|
||
| static unsigned getBitWidth(Type type) { | ||
| if (type.isIntOrFloat()) | ||
| return type.getIntOrFloatBitWidth(); | ||
|
|
||
| auto vec = cast<VectorType>(type); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to assert this is not a scalable vector or bail out in some other way
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, updated verifier to reject scalable vectors, added assert here. |
||
| return vec.getNumElements() * getBitWidth(vec.getElementType()); | ||
| } | ||
|
|
||
| static Value createI32Constant(OpBuilder &builder, Location loc, | ||
| int32_t value) { | ||
| Type i32 = builder.getI32Type(); | ||
| return builder.create<LLVM::ConstantOp>(loc, i32, value); | ||
| } | ||
|
|
||
| SmallVector<Value> mlir::LLVM::decomposeValue(OpBuilder &builder, Location loc, | ||
| Value src, Type dstType) { | ||
| Type srcType = src.getType(); | ||
| if (srcType == dstType) | ||
| return {src}; | ||
|
|
||
| unsigned srcBitWidth = getBitWidth(srcType); | ||
| unsigned dstBitWidth = getBitWidth(dstType); | ||
| if (srcBitWidth == dstBitWidth) { | ||
| Value cast = builder.create<LLVM::BitcastOp>(loc, dstType, src); | ||
| return {cast}; | ||
| } | ||
|
|
||
| if (dstBitWidth > srcBitWidth) { | ||
| auto smallerInt = builder.getIntegerType(srcBitWidth); | ||
| if (srcType != smallerInt) | ||
| src = builder.create<LLVM::BitcastOp>(loc, smallerInt, src); | ||
|
|
||
| auto largerInt = builder.getIntegerType(dstBitWidth); | ||
| Value res = builder.create<LLVM::ZExtOp>(loc, largerInt, src); | ||
| return {res}; | ||
| } | ||
| assert(srcBitWidth % dstBitWidth == 0 && | ||
| "src bit width must be a multiple of dst bit width"); | ||
| int64_t numElements = srcBitWidth / dstBitWidth; | ||
| auto vecType = VectorType::get(numElements, dstType); | ||
|
|
||
| src = builder.create<LLVM::BitcastOp>(loc, vecType, src); | ||
|
|
||
| SmallVector<Value> res; | ||
| for (auto i : llvm::seq<int64_t>(0, numElements)) { | ||
|
||
| Value idx = createI32Constant(builder, loc, i); | ||
| Value elem = builder.create<LLVM::ExtractElementOp>(loc, src, idx); | ||
| res.emplace_back(elem); | ||
| } | ||
|
|
||
| return res; | ||
| } | ||
|
|
||
| Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src, | ||
| Type dstType) { | ||
| assert(!src.empty() && "src range must not be empty"); | ||
| if (src.size() == 1) { | ||
| Value res = src.front(); | ||
| if (res.getType() == dstType) | ||
| return res; | ||
|
|
||
| unsigned srcBitWidth = getBitWidth(res.getType()); | ||
| unsigned dstBitWidth = getBitWidth(dstType); | ||
| if (dstBitWidth < srcBitWidth) { | ||
| auto largerInt = builder.getIntegerType(srcBitWidth); | ||
| if (res.getType() != largerInt) | ||
| res = builder.create<LLVM::BitcastOp>(loc, largerInt, res); | ||
|
|
||
| auto smallerInt = builder.getIntegerType(dstBitWidth); | ||
| res = builder.create<LLVM::TruncOp>(loc, smallerInt, res); | ||
| } | ||
|
|
||
| if (res.getType() != dstType) | ||
| res = builder.create<LLVM::BitcastOp>(loc, dstType, res); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. very cool, didn't know llvm::bitcast can transform from scalar to vector types with same bitwidth |
||
|
|
||
| return res; | ||
| } | ||
|
|
||
| int64_t numElements = src.size(); | ||
| auto srcType = VectorType::get(numElements, src.front().getType()); | ||
| Value res = builder.create<LLVM::PoisonOp>(loc, srcType); | ||
raikonenfnu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| for (auto &&[i, elem] : llvm::enumerate(src)) { | ||
| Value idx = createI32Constant(builder, loc, i); | ||
| res = builder.create<LLVM::InsertElementOp>(loc, srcType, res, elem, idx); | ||
| } | ||
|
|
||
| if (res.getType() != dstType) | ||
| res = builder.create<LLVM::BitcastOp>(loc, dstType, res); | ||
|
|
||
| return res; | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| // RUN: mlir-opt -convert-amdgpu-to-rocdl --canonicalize %s | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: func @test_swizzle_i32 | ||
| // CHECK-SAME: (%[[ARG0:.*]]: i32) | ||
| func.func @test_swizzle_i32(%arg0 : i32) -> i32 { | ||
| // CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 | ||
| // CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ARG0]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: return %[[RES]] : i32 | ||
| %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : i32 | ||
| return %0 : i32 | ||
| } | ||
|
|
||
| // CHECK-LABEL: func @test_swizzle_f32 | ||
| // CHECK-SAME: (%[[ARG0:.*]]: f32) | ||
| func.func @test_swizzle_f32(%arg0 : f32) -> f32 { | ||
| // CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 | ||
| // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32 | ||
| // CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[CAST]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32 | ||
| // CHECK: return %[[RES_CAST]] : f32 | ||
| %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32 | ||
| return %0 : f32 | ||
| } | ||
|
|
||
| // CHECK-LABEL: func @test_swizzle_f16 | ||
| // CHECK-SAME: (%[[ARG0:.*]]: f16) | ||
| func.func @test_swizzle_f16(%arg0 : f16) -> f16 { | ||
| // CHECK: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 | ||
| // CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16 | ||
| // CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32 | ||
| // CHECK: %[[RES:.*]] = rocdl.ds_swizzle %[[ZEXT]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16 | ||
| // CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16 | ||
| // CHECK: return %[[RES_CAST]] : f16 | ||
| %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f16 | ||
| return %0 : f16 | ||
| } | ||
|
|
||
| // CHECK-LABEL: func @test_swizzle_2xi32 | ||
| // CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>) | ||
| func.func @test_swizzle_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> { | ||
| // CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32> | ||
| // CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 | ||
| // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 | ||
| // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 | ||
| // CHECK: %[[E0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32> | ||
| // CHECK: %[[E1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32> | ||
| // CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32> | ||
| // CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32> | ||
| // CHECK: return %[[V3]] : vector<2xi32> | ||
| %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<2xi32> | ||
| return %0 : vector<2xi32> | ||
| } | ||
|
|
||
| // CHECK-LABEL: func @test_swizzle_4xf16 | ||
| // CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>) | ||
| func.func @test_swizzle_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> { | ||
| // CHECK-DAG: %[[V1:.*]] = llvm.mlir.poison : vector<2xi32> | ||
| // CHECK-DAG: %[[C:.*]] = llvm.mlir.constant(4161 : i32) : i32 | ||
| // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 | ||
| // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 | ||
| // CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32> | ||
| // CHECK: %[[E0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32> | ||
| // CHECK: %[[E1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32> | ||
| // CHECK: %[[S1:.*]] = rocdl.ds_swizzle %[[E0]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[S2:.*]] = rocdl.ds_swizzle %[[E1]], %[[C]] : (i32, i32) -> i32 | ||
| // CHECK: %[[V2:.*]] = llvm.insertelement %[[S1]], %[[V1]][%[[C0]] : i32] : vector<2xi32> | ||
| // CHECK: %[[V3:.*]] = llvm.insertelement %[[S2]], %[[V2]][%[[C1]] : i32] : vector<2xi32> | ||
| // CHECK: %[[CAST2:.*]] = llvm.bitcast %[[V3]] : vector<2xi32> to vector<4xf16> | ||
| // CHECK: return %[[CAST2]] : vector<4xf16> | ||
| %0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : vector<4xf16> | ||
| return %0 : vector<4xf16> | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to this PR, but I think it would be more intuitive if the op were named
`dsSwizzleOp`, with the QDMode and the BitMode as an attribute of that op. Additionally, we could also reference https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The QDMode and BitMode variants use different offset formats. My idea was to have them as separate ops in
AMDGPU, so users won't need to bother with offset bit-packing, even though both are eventually lowered to `ds_swizzle`.