@@ -80,6 +80,24 @@ static constexpr StringLiteral amdgcnDataLayout =
8080 " 64-S32-A5-G1-ni:7:8:9" ;
8181
8282namespace {
83+
84+ // Truncate or extend the result depending on the index bitwidth specified
85+ // by the LLVMTypeConverter options.
86+ static Value truncOrExtToLLVMType (ConversionPatternRewriter &rewriter,
87+ Location loc, Value value,
88+ const LLVMTypeConverter *converter) {
89+ auto intWidth = cast<IntegerType>(value.getType ()).getWidth ();
90+ auto indexBitwidth = converter->getIndexTypeBitwidth ();
91+ if (indexBitwidth > intWidth) {
92+ return rewriter.create <LLVM::SExtOp>(
93+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
94+ } else if (indexBitwidth < intWidth) {
95+ return rewriter.create <LLVM::TruncOp>(
96+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
97+ }
98+ return value;
99+ }
100+
83101struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
84102 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
85103
@@ -98,16 +116,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
98116 rewriter.create <ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
99117 Value laneId = rewriter.create <ROCDL::MbcntHiOp>(
100118 loc, intTy, ValueRange{minus1, mbcntLo});
101- // Truncate or extend the result depending on the index bitwidth specified
102- // by the LLVMTypeConverter options.
103- const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
104- if (indexBitwidth > 32 ) {
105- laneId = rewriter.create <LLVM::SExtOp>(
106- loc, IntegerType::get (context, indexBitwidth), laneId);
107- } else if (indexBitwidth < 32 ) {
108- laneId = rewriter.create <LLVM::TruncOp>(
109- loc, IntegerType::get (context, indexBitwidth), laneId);
110- }
119+ laneId = truncOrExtToLLVMType (rewriter, loc, laneId, getTypeConverter ());
111120 rewriter.replaceOp (op, {laneId});
112121 return success ();
113122 }
@@ -190,6 +199,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
190199 }
191200};
192201
202+ struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
203+ using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;
204+
205+ LogicalResult
206+ matchAndRewrite (gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
207+ ConversionPatternRewriter &rewriter) const override {
208+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
209+ Value waveIdOp = rewriter.create <ROCDL::WaveIdOp>(op.getLoc (), int32Type);
210+ waveIdOp = truncOrExtToLLVMType (rewriter, op.getLoc (), waveIdOp,
211+ getTypeConverter ());
212+ rewriter.replaceOp (op, {waveIdOp});
213+ return success ();
214+ }
215+ };
216+
193217// / Import the GPU Ops to ROCDL Patterns.
194218#include " GPUToROCDL.cpp.inc"
195219
@@ -405,7 +429,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405429 // TODO: Add alignment for workgroup memory
406430 patterns.add <GPUDynamicSharedMemoryOpLowering>(converter);
407431
408- patterns.add <GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
432+ patterns
433+ .add <GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
434+ converter);
409435
410436 populateMathToROCDLConversionPatterns (converter, patterns);
411437}
0 commit comments