@@ -316,6 +316,53 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
316316 }
317317};
318318
319+ // ===----------------------------------------------------------------------===//
320+ // Subgroup query ops.
321+ // ===----------------------------------------------------------------------===//
322+
323+ template <typename SubgroupOp>
324+ struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
325+ using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
326+ using ConvertToLLVMPattern::getTypeConverter;
327+
328+ LogicalResult
329+ matchAndRewrite (SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
330+ ConversionPatternRewriter &rewriter) const final {
331+ constexpr StringRef funcName = [] {
332+ if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
333+ return " _Z16get_sub_group_id" ;
334+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
335+ return " _Z22get_sub_group_local_id" ;
336+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
337+ return " _Z18get_num_sub_groups" ;
338+ } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
339+ return " _Z18get_sub_group_size" ;
340+ }
341+ }();
342+
343+ Operation *moduleOp =
344+ op->template getParentWithTrait <OpTrait::SymbolTable>();
345+ Type resultTy = rewriter.getI32Type ();
346+ LLVM::LLVMFuncOp func =
347+ lookupOrCreateSPIRVFn (moduleOp, funcName, {}, resultTy,
348+ /* isMemNone=*/ false , /* isConvergent=*/ false );
349+
350+ Location loc = op->getLoc ();
351+ Value result = createSPIRVBuiltinCall (loc, rewriter, func, {}).getResult ();
352+
353+ Type indexTy = getTypeConverter ()->getIndexType ();
354+ if (resultTy != indexTy) {
355+ if (indexTy.getIntOrFloatBitWidth () < resultTy.getIntOrFloatBitWidth ()) {
356+ return failure ();
357+ }
358+ result = rewriter.create <LLVM::ZExtOp>(loc, indexTy, result);
359+ }
360+
361+ rewriter.replaceOp (op, result);
362+ return success ();
363+ }
364+ };
365+
319366// ===----------------------------------------------------------------------===//
320367// GPU To LLVM-SPV Pass.
321368// ===----------------------------------------------------------------------===//
@@ -337,7 +384,9 @@ struct GPUToLLVMSPVConversionPass final
337384
338385 target.addIllegalOp <gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
339386 gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
340- gpu::ReturnOp, gpu::ShuffleOp, gpu::ThreadIdOp>();
387+ gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
388+ gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
389+ gpu::ThreadIdOp>();
341390
342391 populateGpuToLLVMSPVConversionPatterns (converter, patterns);
343392 populateGpuMemorySpaceAttributeConversions (converter);
@@ -366,11 +415,15 @@ gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
366415void populateGpuToLLVMSPVConversionPatterns (LLVMTypeConverter &typeConverter,
367416 RewritePatternSet &patterns) {
368417 patterns.add <GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
418+ GPUSubgroupOpConversion<gpu::LaneIdOp>,
419+ GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
420+ GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
421+ GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
422+ LaunchConfigOpConversion<gpu::BlockDimOp>,
369423 LaunchConfigOpConversion<gpu::BlockIdOp>,
424+ LaunchConfigOpConversion<gpu::GlobalIdOp>,
370425 LaunchConfigOpConversion<gpu::GridDimOp>,
371- LaunchConfigOpConversion<gpu::BlockDimOp>,
372- LaunchConfigOpConversion<gpu::ThreadIdOp>,
373- LaunchConfigOpConversion<gpu::GlobalIdOp>>(typeConverter);
426+ LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
374427 MLIRContext *context = &typeConverter.getContext ();
375428 unsigned privateAddressSpace =
376429 gpuAddressSpaceToOCLAddressSpace (gpu::AddressSpace::Private);
0 commit comments