-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][ROCDL] Lower gpu.subgroup_size to wavefrontsize
#137360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -52,6 +52,25 @@ namespace mlir { | |
|
|
||
| using namespace mlir; | ||
|
|
||
| // Truncate or extend the result depending on the index bitwidth specified | ||
| // by the LLVMTypeConverter options. | ||
| static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter, | ||
| Location loc, Value value, | ||
| const LLVMTypeConverter &converter) { | ||
| int64_t intWidth = cast<IntegerType>(value.getType()).getWidth(); | ||
| int64_t indexBitwidth = converter.getIndexTypeBitwidth(); | ||
| auto indexBitwidthType = | ||
| IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth()); | ||
| // TODO: use <=> in C++20. | ||
| if (indexBitwidth > intWidth) { | ||
| return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value); | ||
| } | ||
| if (indexBitwidth < intWidth) { | ||
| return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value); | ||
| } | ||
| return value; | ||
| } | ||
|
|
||
| /// Returns true if the given `gpu.func` can be safely called using the bare | ||
| /// pointer calling convention. | ||
| static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { | ||
|
|
@@ -113,6 +132,26 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { | |
| } | ||
| }; | ||
|
|
||
| struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> { | ||
| using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; | ||
| LogicalResult | ||
| matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| LLVM::ConstantRangeAttr bounds = nullptr; | ||
| if (auto upperBoundAttr = op.getUpperBoundAttr()) { | ||
| bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>( | ||
| /*bitWidth=*/32, /*lower=*/32, | ||
|
||
| /*upper=*/op.getUpperBoundAttr().getInt()); | ||
|
||
| } | ||
| Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>( | ||
lialan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| op.getLoc(), rewriter.getI32Type(), bounds); | ||
| wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp, | ||
| *getTypeConverter()); | ||
| rewriter.replaceOp(op, {wavefrontOp}); | ||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { | ||
| using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern; | ||
|
|
||
|
|
@@ -405,7 +444,9 @@ void mlir::populateGpuToROCDLConversionPatterns( | |
| // TODO: Add alignment for workgroup memory | ||
| patterns.add<GPUDynamicSharedMemoryOpLowering>(converter); | ||
|
|
||
| patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter); | ||
| patterns | ||
| .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>( | ||
| converter); | ||
|
|
||
| populateMathToROCDLConversionPatterns(converter, patterns); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also use this in the code below, e.g.,
GPULaneIdOpToROCDL?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is refactored in the other PR: https://github.com/llvm/llvm-project/pull/136405/files#diff-cd4257dddc1cb3043071e5c7641774615ffd685cc779acf70a47a3e83401b515R141