@@ -52,6 +52,25 @@ namespace mlir {
 
 using namespace mlir;
 
+// Truncate or extend the result depending on the index bitwidth specified
+// by the LLVMTypeConverter options.
+static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value value,
+                                  const LLVMTypeConverter &converter) {
+  int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
+  int64_t indexBitwidth = converter.getIndexTypeBitwidth();
+  auto indexBitwidthType =
+      IntegerType::get(rewriter.getContext(), indexBitwidth);
+  // TODO: use <=> in C++20.
+  if (indexBitwidth > intWidth) {
+    return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+  }
+  if (indexBitwidth < intWidth) {
+    return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+  }
+  return value;
+}
+
 /// Returns true if the given `gpu.func` can be safely called using the bare
 /// pointer calling convention.
 static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
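
For reference, a sketch of the helper's effect, assuming the default 64-bit index bitwidth (configurable via the lowering pass's index-bitwidth option); the %v32 and %v64 values are hypothetical:

// i32 intrinsic result, 64-bit index type: sign-extend.
%ext = llvm.sext %v32 : i32 to i64
// Wider-than-index result under index-bitwidth=32: truncate.
%trunc = llvm.trunc %v64 : i64 to i32
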
@@ -113,6 +132,37 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
   }
 };
 
+/// Lowers gpu.subgroup_size to the ROCDL wavefront-size op, attaching
+/// chipset-derived range bounds when an upper bound is known.
+struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
+                           amdgpu::Chipset chipset)
+      : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
+        chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    LLVM::ConstantRangeAttr bounds = nullptr;
+    bool isBeforeGfx10 = chipset.majorVersion < 10;
+    if (auto upperBoundAttr = op.getUpperBoundAttr()) {
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
+          /*upper=*/upperBoundAttr.getInt() + 1);
+    }
+    Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
+        op.getLoc(), rewriter.getI32Type(), bounds);
+    wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
+                                       *getTypeConverter());
+    rewriter.replaceOp(op, {wavefrontOp});
+    return success();
+  }
+
+  const amdgpu::Chipset chipset;
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
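
The range annotation mirrors the hardware: wavefronts are always 64 lanes wide before gfx10, while gfx10+ supports both wave32 and wave64, hence the chipset-dependent lower bound. A sketch of the resulting lowering on a pre-gfx10 target; the upper_bound syntax and the printed range form below are approximations, not lines taken from the commit's tests:

%sz = gpu.subgroup_size upper_bound 64 : index
// becomes, on e.g. gfx90a with a 64-bit index type, roughly:
// %0 = rocdl.wavefrontsize range(<i32, 64, 65>) : i32
// %1 = llvm.sext %0 : i32 to i64
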
@@ -322,7 +372,8 @@ struct LowerGpuOpsToROCDLOpsPass final
 
     populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
                                             *maybeChipset);
-    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
+    populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
+                                         *maybeChipset);
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
@@ -370,7 +421,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
 
 void mlir::populateGpuToROCDLConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    mlir::gpu::amd::Runtime runtime) {
+    mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
   using gpu::index_lowering::IndexKind;
   using gpu::index_lowering::IntrType;
   using mlir::gpu::amd::Runtime;
@@ -408,7 +459,8 @@ void mlir::populateGpuToROCDLConversionPatterns(
   // TODO: Add alignment for workgroup memory
   patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
+  patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
 
   populateMathToROCDLConversionPatterns(converter, patterns);
 }
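
End to end, a hedged lit-test-style sketch of the new lowering; the module, function, and CHECK lines are illustrative assumptions, not the commit's actual tests:

// RUN: mlir-opt %s -convert-gpu-to-rocdl=chipset=gfx900 | FileCheck %s
gpu.module @test_module {
  // CHECK-LABEL: @subgroup_size
  // CHECK: rocdl.wavefrontsize : i32
  // CHECK: llvm.sext %{{.*}} : i32 to i64
  func.func @subgroup_size() -> index {
    %0 = gpu.subgroup_size : index
    return %0 : index
  }
}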