@@ -80,6 +80,23 @@ static constexpr StringLiteral amdgcnDataLayout =
8080 " 64-S32-A5-G1-ni:7:8:9" ;
8181
8282namespace {
83+
84+ // Truncate or extend the result depending on the index bitwidth specified
85+ // by the LLVMTypeConverter options.
86+ template <int64_t N>
87+ static Value truncOrExtToLLVMType (ConversionPatternRewriter &rewriter,
88+ Location loc, Value value,
89+ const unsigned indexBitwidth) {
90+ if (indexBitwidth > N) {
91+ return rewriter.create <LLVM::SExtOp>(
92+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
93+ } else if (indexBitwidth < N) {
94+ return rewriter.create <LLVM::TruncOp>(
95+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
96+ }
97+ return value;
98+ }
99+
83100struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
84101 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
85102
@@ -98,16 +115,8 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
98115 rewriter.create <ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
99116 Value laneId = rewriter.create <ROCDL::MbcntHiOp>(
100117 loc, intTy, ValueRange{minus1, mbcntLo});
101- // Truncate or extend the result depending on the index bitwidth specified
102- // by the LLVMTypeConverter options.
103118 const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
104- if (indexBitwidth > 32 ) {
105- laneId = rewriter.create <LLVM::SExtOp>(
106- loc, IntegerType::get (context, indexBitwidth), laneId);
107- } else if (indexBitwidth < 32 ) {
108- laneId = rewriter.create <LLVM::TruncOp>(
109- loc, IntegerType::get (context, indexBitwidth), laneId);
110- }
119+ laneId = truncOrExtToLLVMType<32 >(rewriter, loc, laneId, indexBitwidth);
111120 rewriter.replaceOp (op, {laneId});
112121 return success ();
113122 }
@@ -190,6 +199,24 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
190199 }
191200};
192201
202+ struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
203+ using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;
204+
205+ LogicalResult
206+ matchAndRewrite (gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
207+ ConversionPatternRewriter &rewriter) const override {
208+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
209+ Value waveIdOp = rewriter.create <ROCDL::WaveIdOp>(op.getLoc (), int32Type);
210+
211+ waveIdOp =
212+ truncOrExtToLLVMType<32 >(rewriter, op.getLoc (), waveIdOp,
213+ getTypeConverter ()->getIndexTypeBitwidth ());
214+
215+ rewriter.replaceOp (op, {waveIdOp});
216+ return success ();
217+ }
218+ };
219+
193220// / Import the GPU Ops to ROCDL Patterns.
194221#include " GPUToROCDL.cpp.inc"
195222
@@ -405,7 +432,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405432 // TODO: Add alignment for workgroup memory
406433 patterns.add <GPUDynamicSharedMemoryOpLowering>(converter);
407434
408- patterns.add <GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
435+ patterns
436+ .add <GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
437+ converter);
409438
410439 populateMathToROCDLConversionPatterns (converter, patterns);
411440}
0 commit comments