@@ -52,25 +52,6 @@ namespace mlir {
 
 using namespace mlir;
 
-// Truncate or extend the result depending on the index bitwidth specified
-// by the LLVMTypeConverter options.
-static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
-                                  Location loc, Value value,
-                                  const LLVMTypeConverter &converter) {
-  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
-  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
-  auto indexBitwidthType =
-      IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
-  // TODO: use <=> in C++20.
-  if (indexBitwidth > intWidth) {
-    return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
-  }
-  if (indexBitwidth < intWidth) {
-    return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
-  }
-  return value;
-}
-
 /// Returns true if the given `gpu.func` can be safely called using the bare
 /// pointer calling convention.
 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -99,6 +80,26 @@ static constexpr StringLiteral amdgcnDataLayout =
     "64-S32-A5-G1-ni:7:8:9";
 
 namespace {
+
+// Truncate or extend the result depending on the index bitwidth specified
+// by the LLVMTypeConverter options.
+static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value value,
+                                  const LLVMTypeConverter &converter) {
+  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
+  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
+  auto indexBitwidthType =
+      IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
+  // TODO: use <=> in C++20.
+  if (indexBitwidth > intWidth) {
+    return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+  }
+  if (indexBitwidth < intWidth) {
+    return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+  }
+  return value;
+}
+
 struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
 
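Hoisting the helper into the anonymous namespace ahead of the patterns lets every lowering below share it. A rough illustration of its behavior follows; this is a sketch, not code from the patch, and it assumes a pattern body where `rewriter`, `loc`, and a `converter` configured with index-bitwidth=64 are already in scope:

```cpp
// Hypothetical snippet inside some matchAndRewrite; `v` is an arbitrary i32.
Value v = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
                                            rewriter.getI32IntegerAttr(7));
// With a 64-bit index type this emits `llvm.sext ... : i32 to i64`; with a
// 32-bit index type it returns `v` unchanged; with a narrower index type it
// emits `llvm.trunc` instead.
Value asIndex = truncOrExtToLLVMType(rewriter, loc, v, converter);
```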
@@ -117,16 +118,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
         rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
     Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
         loc, intTy, ValueRange{minus1, mbcntLo});
-    // Truncate or extend the result depending on the index bitwidth specified
-    // by the LLVMTypeConverter options.
-    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
-    if (indexBitwidth > 32) {
-      laneId = rewriter.create<LLVM::SExtOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    } else if (indexBitwidth < 32) {
-      laneId = rewriter.create<LLVM::TruncOp>(
-          loc, IntegerType::get(context, indexBitwidth), laneId);
-    }
+    laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter());
     rewriter.replaceOp(op, {laneId});
     return success();
   }
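The lane-id computation itself is the standard AMDGPU mbcnt idiom: with an all-ones mask, `mbcnt.lo` counts how many of lanes 0-31 sit below the current lane, and `mbcnt.hi` adds the count from lanes 32-63, so the chained pair evaluates to the lane id on a wave64 target. A scalar C++ model of what the two intrinsics compute for a single lane (an illustration under that reading, not code from the patch):

```cpp
#include <bit>
#include <cstdint>

// Scalar model of mbcnt.lo/mbcnt.hi chained as in the lowering above.
// With mask == ~0ull the result is exactly `lane`.
uint32_t laneIdModel(uint64_t mask, uint32_t lane /* 0..63 */) {
  uint64_t below = mask & ((1ull << lane) - 1); // set bits strictly below us
  uint32_t lo = std::popcount(static_cast<uint32_t>(below)); // mbcnt.lo(mask, 0)
  return lo + std::popcount(static_cast<uint32_t>(below >> 32)); // mbcnt.hi(mask, lo)
}
```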
@@ -150,11 +142,11 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
           /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
           /*upper=*/op.getUpperBoundAttr().getInt() + 1);
     }
-    Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
+    Value wavefrontSizeOp = rewriter.create<ROCDL::WavefrontSizeOp>(
         op.getLoc(), rewriter.getI32Type(), bounds);
-    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
-                                       *getTypeConverter());
-    rewriter.replaceOp(op, {wavefrontOp});
+    wavefrontSizeOp = truncOrExtToLLVMType(
+        rewriter, op.getLoc(), wavefrontSizeOp, *getTypeConverter());
+    rewriter.replaceOp(op, {wavefrontSizeOp});
     return success();
   }
 
@@ -239,6 +231,65 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+struct GPUSubgroupIdOpToROCDL final
+    : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Calculation of the thread's subgroup identifier.
+    //
+    // The process involves mapping the thread's 3D identifier within its
+    // workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index.
+    // This linearization assumes a layout where the x-dimension (w_dim.x)
+    // varies most rapidly (i.e., it is the innermost dimension).
+    //
+    // The formula for the linearized thread index is:
+    //   L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
+    //
+    // Subsequently, the range of linearized indices [0, N_threads-1] is
+    // divided into consecutive, non-overlapping segments, each representing
+    // a subgroup of size 'subgroup_size'.
+    //
+    // Example partitioning (N = subgroup_size):
+    //   | Subgroup 0     | Subgroup 1      | Subgroup 2       | ... |
+    //   | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1 | ... |
+    //
+    // The subgroup identifier is obtained via integer division of the
+    // linearized thread index by the predefined 'subgroup_size'.
+    //
+    //   subgroup_id = floor(L / subgroup_size)
+    //               = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) /
+    //                 subgroup_size
+    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    Location loc = op.getLoc();
+    LLVM::IntegerOverflowFlags flags =
+        LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
+    Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
+    Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
+    Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
+    Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
+    Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
+    Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
+                                                  workitemIdZ, flags);
+    Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, dimYxIdZ, workitemIdY, flags);
+    Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
+        loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
+    Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
+        rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
+                                     dimYxIdZPlusIdYTimesDimX, flags);
+    Value subgroupSize = rewriter.create<ROCDL::WavefrontSizeOp>(
+        loc, rewriter.getI32Type(), nullptr);
+    Value waveIdOp = rewriter.create<LLVM::UDivOp>(
+        loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+    rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp,
+                                                 *getTypeConverter())});
+    return success();
+  }
+};
+
 /// Import the GPU Ops to ROCDL Patterns.
 #include "GPUToROCDL.cpp.inc"
 
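A worked instance of the formula in the comment above, with hypothetical launch dimensions. The nsw/nuw flags are sound here because ids and dims are non-negative and, for any realistic workgroup size, the linear index fits comfortably in i32:

```cpp
#include <cassert>
#include <cstdint>

// Worked example of the linearization; all values are hypothetical.
int main() {
  uint32_t dimX = 64, dimY = 4;       // w_dim.x, w_dim.y
  uint32_t idX = 5, idY = 2, idZ = 1; // w_id within the workgroup
  uint32_t subgroupSize = 64;         // wave64
  uint32_t linear = idX + dimX * (idY + dimY * idZ); // 5 + 64*(2 + 4*1) = 389
  assert(linear / subgroupSize == 6); // threads 384..447 form subgroup 6
  return 0;
}
```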
@@ -249,19 +300,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
 // code.
 struct LowerGpuOpsToROCDLOpsPass final
     : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
-  LowerGpuOpsToROCDLOpsPass() = default;
-  LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
-                            bool useBarePtrCallConv,
-                            gpu::amd::Runtime runtime) {
-    if (this->chipset.getNumOccurrences() == 0)
-      this->chipset = chipset;
-    if (this->indexBitwidth.getNumOccurrences() == 0)
-      this->indexBitwidth = indexBitwidth;
-    if (this->useBarePtrCallConv.getNumOccurrences() == 0)
-      this->useBarePtrCallConv = useBarePtrCallConv;
-    if (this->runtime.getNumOccurrences() == 0)
-      this->runtime = runtime;
-  }
+  using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
     Base::getDependentDialects(registry);
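`using Base::Base;` inherits the constructors that tablegen generates for the pass, including one taking an options struct, which makes the deleted hand-written option forwarding redundant. A hedged construction sketch; the `ConvertGpuOpsToROCDLOpsOptions` name and its fields assume the usual generated API for this pass definition:

```cpp
// Sketch only: builds the pass from the generated options struct instead of
// the removed four-argument constructor. Field values are illustrative.
ConvertGpuOpsToROCDLOpsOptions options;
options.chipset = "gfx90a";
options.indexBitwidth = 32;
auto pass = std::make_unique<LowerGpuOpsToROCDLOpsPass>(options);
```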
@@ -455,17 +494,15 @@ void mlir::populateGpuToROCDLConversionPatterns(
   // TODO: Add alignment for workgroup memory
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
-  patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+  patterns
+      .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
+          converter);
   patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
 
   populateMathToROCDLConversionPatterns(converter, patterns);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
-                                      unsigned indexBitwidth,
-                                      bool useBarePtrCallConv,
-                                      gpu::amd::Runtime runtime) {
-  return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
-      chipset, indexBitwidth, useBarePtrCallConv, runtime);
+mlir::createLowerGpuOpsToROCDLOpsPass() {
+  return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
 }
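With the factory now taking no arguments, callers configure the pass through pass-pipeline options instead. A hedged usage sketch; it assumes the pass is registered and that the `convert-gpu-to-rocdl` mnemonic with `chipset`/`index-bitwidth` options matches the existing pass definition:

```cpp
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Parse a textual pipeline that nests the pass under gpu.module and sets its
// options; the option values here are illustrative.
void buildPipeline(mlir::PassManager &pm) {
  (void)mlir::parsePassPipeline(
      "gpu.module(convert-gpu-to-rocdl{chipset=gfx90a index-bitwidth=32})",
      pm);
}
```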