@@ -121,6 +121,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
121121 ConversionPatternRewriter &rewriter) const override ;
122122};
123123
124+ class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
125+ public:
126+ using OpConversionPattern::OpConversionPattern;
127+
128+ LogicalResult
129+ matchAndRewrite (gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
130+ ConversionPatternRewriter &rewriter) const override ;
131+ };
132+
124133} // namespace
125134
126135// ===----------------------------------------------------------------------===//
@@ -597,6 +606,124 @@ class GPUSubgroupReduceConversion final
597606 }
598607};
599608
609+ // Formulate a unique variable/constant name after
610+ // searching in the module for existing variable/constant names.
611+ // This is to avoid name collision with existing variables.
612+ // Example: printfMsg0, printfMsg1, printfMsg2, ...
613+ static std::string makeVarName (spirv::ModuleOp moduleOp, llvm::Twine prefix) {
614+ std::string name;
615+ unsigned number = 0 ;
616+
617+ do {
618+ name.clear ();
619+ name = (prefix + llvm::Twine (number++)).str ();
620+ } while (moduleOp.lookupSymbol (name));
621+
622+ return name;
623+ }
624+
625+ // / Pattern to convert a gpu.printf op into a SPIR-V CLPrintf op.
626+
627+ LogicalResult GPUPrintfConversion::matchAndRewrite (
628+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
629+ ConversionPatternRewriter &rewriter) const {
630+
631+ Location loc = gpuPrintfOp.getLoc ();
632+
633+ auto moduleOp = gpuPrintfOp->getParentOfType <spirv::ModuleOp>();
634+ if (!moduleOp)
635+ return failure ();
636+
637+ // SPIR-V global variable is used to initialize printf
638+ // format string value, if there are multiple printf messages,
639+ // each global var needs to be created with a unique name.
640+ std::string globalVarName = makeVarName (moduleOp, llvm::Twine (" printfMsg" ));
641+ spirv::GlobalVariableOp globalVar;
642+
643+ IntegerType i8Type = rewriter.getI8Type ();
644+ IntegerType i32Type = rewriter.getI32Type ();
645+
646+ // Each character of printf format string is
647+ // stored as a spec constant. We need to create
648+ // unique name for this spec constant like
649+ // @printfMsg0_sc0, @printfMsg0_sc1, ... by searching in the module
650+ // for existing spec constant names.
651+ auto createSpecConstant = [&](unsigned value) {
652+ auto attr = rewriter.getI8IntegerAttr (value);
653+ std::string specCstName =
654+ makeVarName (moduleOp, llvm::Twine (globalVarName) + " _sc" );
655+
656+ return rewriter.create <spirv::SpecConstantOp>(
657+ loc, rewriter.getStringAttr (specCstName), attr);
658+ };
659+ {
660+ Operation *parent =
661+ SymbolTable::getNearestSymbolTable (gpuPrintfOp->getParentOp ());
662+
663+ ConversionPatternRewriter::InsertionGuard guard (rewriter);
664+
665+ Block &entryBlock = *parent->getRegion (0 ).begin ();
666+ rewriter.setInsertionPointToStart (
667+ &entryBlock); // insertion point at module level
668+
669+ // Create Constituents with SpecConstant by scanning format string
670+ // Each character of format string is stored as a spec constant
671+ // and then these spec constants are used to create a
672+ // SpecConstantCompositeOp.
673+ llvm::SmallString<20 > formatString (adaptor.getFormat ());
674+ formatString.push_back (' \0 ' ); // Null terminate for C.
675+ SmallVector<Attribute, 4 > constituents;
676+ for (char c : formatString) {
677+ spirv::SpecConstantOp cSpecConstantOp = createSpecConstant (c);
678+ constituents.push_back (SymbolRefAttr::get (cSpecConstantOp));
679+ }
680+
681+ // Create SpecConstantCompositeOp to initialize the global variable
682+ size_t contentSize = constituents.size ();
683+ auto globalType = spirv::ArrayType::get (i8Type, contentSize);
684+ spirv::SpecConstantCompositeOp specCstComposite;
685+ // There will be one SpecConstantCompositeOp per printf message/global var,
686+ // so no need do lookup for existing ones.
687+ std::string specCstCompositeName =
688+ (llvm::Twine (globalVarName) + " _scc" ).str ();
689+
690+ specCstComposite = rewriter.create <spirv::SpecConstantCompositeOp>(
691+ loc, TypeAttr::get (globalType),
692+ rewriter.getStringAttr (specCstCompositeName),
693+ rewriter.getArrayAttr (constituents));
694+
695+ auto ptrType = spirv::PointerType::get (
696+ globalType, spirv::StorageClass::UniformConstant);
697+
698+ // Define a GlobalVarOp initialized using specialized constants
699+ // that is used to specify the printf format string
700+ // to be passed to the SPIRV CLPrintfOp.
701+ globalVar = rewriter.create <spirv::GlobalVariableOp>(
702+ loc, ptrType, globalVarName, FlatSymbolRefAttr::get (specCstComposite));
703+
704+ globalVar->setAttr (" Constant" , rewriter.getUnitAttr ());
705+ }
706+ // Get SSA value of Global variable and create pointer to i8 to point to
707+ // the format string.
708+ Value globalPtr = rewriter.create <spirv::AddressOfOp>(loc, globalVar);
709+ Value fmtStr = rewriter.create <spirv::BitcastOp>(
710+ loc,
711+ spirv::PointerType::get (i8Type, spirv::StorageClass::UniformConstant),
712+ globalPtr);
713+
714+ // Get printf arguments.
715+ auto printfArgs = llvm::to_vector_of<Value, 4 >(adaptor.getArgs ());
716+
717+ rewriter.create <spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
718+
719+ // Need to erase the gpu.printf op as gpu.printf does not use result vs
720+ // spirv::CLPrintfOp has i32 resultType so cannot replace with new SPIR-V
721+ // printf op.
722+ rewriter.eraseOp (gpuPrintfOp);
723+
724+ return success ();
725+ }
726+
600727// ===----------------------------------------------------------------------===//
601728// GPU To SPIRV Patterns.
602729// ===----------------------------------------------------------------------===//
@@ -620,5 +747,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
620747 SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
621748 spirv::BuiltIn::SubgroupSize>,
622749 WorkGroupSizeConversion, GPUAllReduceConversion,
623- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext ());
750+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
751+ patterns.getContext ());
624752}
0 commit comments