@@ -52,6 +52,25 @@ namespace mlir {
 
 using namespace mlir;
 
+// Truncate or extend the result depending on the index bitwidth specified
+// by the LLVMTypeConverter options.
+static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value value,
+                                  const LLVMTypeConverter &converter) {
+  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
+  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
+  auto indexBitwidthType =
+      IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
+  // TODO: use <=> in C++20.
+  if (indexBitwidth > intWidth) {
+    return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+  }
+  if (indexBitwidth < intWidth) {
+    return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+  }
+  return value;
+}
+
 /// Returns true if the given `gpu.func` can be safely called using the bare
 /// pointer calling convention.
 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -113,6 +132,20 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
+        op.getLoc(), IntegerType::get(rewriter.getContext(), 32));
+    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
+                                       *getTypeConverter());
+    rewriter.replaceOp(op, {wavefrontOp});
+    return success();
+  }
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
@@ -405,7 +438,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
   // TODO: Add alignment for workgroup memory
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
-  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+  patterns
+      .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
+          converter);
 
   populateMathToROCDLConversionPatterns(converter, patterns);
 }
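
For illustration only, not part of the diff: with the type converter configured for a 64-bit index bitwidth, the new GPUSubgroupSizeOpToROCDL pattern is expected to rewrite `gpu.subgroup_size` roughly as sketched below. The SSA value names are made up, and the exact textual assembly of `rocdl.wavefrontsize` is assumed here rather than taken from the patch.

    // Before the GPU-to-ROCDL conversion:
    %sz = gpu.subgroup_size : index

    // After the conversion (sketch): query the wavefront size as i32,
    // then sign-extend to the 64-bit index type via truncOrExtToLLVMType.
    %0 = rocdl.wavefrontsize : i32
    %1 = llvm.sext %0 : i32 to i64

With a 32-bit index bitwidth the helper would instead return the i32 value unchanged, and for narrower index types it would emit an llvm.trunc.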