@@ -134,33 +134,72 @@ const unsigned int PrintfBufferSize = 4 * MB;
134134// found anything that is not :
135135// * a CastInst
136136// * a GEP with non-zero indices
137- inline GlobalVariable* getGlobalVariable (Value* const v)
137+ // * a SelectInst
138+ // * a PHINode
139+ // In case of select or phi instruction two operands are added to the vector.
140+ // In another case only one is added.
141+ inline SmallVector<Value*, 2 > getGlobalVariable (Value* const v)
138142{
139- Value* curr = v;
140- while (nullptr != curr)
143+ SmallVector<Value *, 2 > curr;
144+ curr.push_back (v);
145+
146+ while (nullptr != curr.front () || nullptr != curr.back ())
141147 {
142- if (isa<GlobalVariable>(curr))
148+ if (curr.size () == 1 && isa<GlobalVariable>(curr.front ()))
149+ {
150+ break ;
151+ }
152+ else if (curr.size () == 2 && (isa<GlobalVariable>(curr.front ()) && isa<GlobalVariable>(curr.back ())))
143153 {
144154 break ;
145155 }
146156
147- if (CastInst * castInst = dyn_cast<CastInst>(curr))
157+ if (CastInst* castInst = dyn_cast<CastInst>(curr. front () ))
148158 {
149- curr = castInst->getOperand (0 );
159+ curr.pop_back ();
160+ curr.push_back (castInst->getOperand (0 ));
150161 }
151- else if (GetElementPtrInst * getElemPtrInst = dyn_cast<GetElementPtrInst>(curr))
162+ else if (GetElementPtrInst* getElemPtrInst = dyn_cast<GetElementPtrInst>(curr. front () ))
152163 {
153- curr = getElemPtrInst->hasAllZeroIndices () ? getElemPtrInst->getPointerOperand () : nullptr ;
164+ if (curr.size () == 2 )
165+ {
166+ if (GetElementPtrInst* getElemPtrInst2 = dyn_cast<GetElementPtrInst>(curr.back ()))
167+ {
168+ curr.pop_back ();
169+ curr.pop_back ();
170+ curr.push_back (getElemPtrInst->hasAllZeroIndices () ? getElemPtrInst->getPointerOperand () : nullptr );
171+ curr.push_back (getElemPtrInst2->hasAllZeroIndices () ? getElemPtrInst2->getPointerOperand () : nullptr );
172+ }
173+ }
174+ else
175+ {
176+ curr.pop_back ();
177+ curr.push_back (getElemPtrInst->hasAllZeroIndices () ? getElemPtrInst->getPointerOperand () : nullptr );
178+ }
179+ }
180+ else if (SelectInst* selectInst = dyn_cast<SelectInst>(curr.front ()))
181+ {
182+ curr.pop_back ();
183+ curr.push_back (selectInst->getOperand (1 ));
184+ curr.push_back (selectInst->getOperand (2 ));
185+ }
186+ else if (PHINode* phiNode = dyn_cast<PHINode>(curr.front ()))
187+ {
188+ curr.pop_back ();
189+ curr.push_back (phiNode->getOperand (0 ));
190+ curr.push_back (phiNode->getOperand (1 ));
154191 }
155192 else
156193 {
157194 // Unhandled value type
158- assert ((false == isa<ConstantExpr>(curr)));
159- curr = nullptr ;
195+ curr.front () = nullptr ;
196+ if (curr.size () == 2 )
197+ {
198+ curr.back () = nullptr ;
199+ }
160200 }
161201 }
162-
163- return dyn_cast_or_null<GlobalVariable>(curr);
202+ return curr;
164203}
165204
166205OpenCLPrintfResolution::OpenCLPrintfResolution () : FunctionPass(ID), m_atomicAddFunc(nullptr )
@@ -282,51 +321,97 @@ std::string OpenCLPrintfResolution::getEscapedString(const ConstantDataSequentia
282321 return Name;
283322}
284323
285- int OpenCLPrintfResolution::processPrintfString (Value* printfArg, Function& F)
324+ Value* OpenCLPrintfResolution::processPrintfString (Value* printfArg, Function& F)
286325{
287- GlobalVariable* formatString = getGlobalVariable (printfArg);
326+ GlobalVariable* formatString = nullptr ;
327+ SmallVector<Value*, 2 > curr = getGlobalVariable (printfArg);
328+ SmallVector<unsigned int , 2 > sv;
329+ for (auto curr_i : curr)
330+ {
331+ auto & curr_e = *curr_i;
288332
289- ConstantDataArray* formatStringConst = ((nullptr != formatString) && (formatString->hasInitializer ())) ?
290- dyn_cast<ConstantDataArray>(formatString->getInitializer ()) :
291- nullptr ;
333+ formatString = dyn_cast_or_null<GlobalVariable>(&curr_e);
334+ ConstantDataArray* formatStringConst = ((nullptr != formatString) && (formatString->hasInitializer ())) ?
335+ dyn_cast<ConstantDataArray>(formatString->getInitializer ()) :
336+ nullptr ;
292337
293- if (nullptr == formatStringConst)
294- {
295- assert (0 && " Unexpected printf argument (expected string literal)" );
296- return 0 ;
297- }
338+ if (nullptr == formatStringConst)
339+ {
340+ assert (0 && " Unexpected printf argument (expected string literal)" );
341+ return 0 ;
342+ }
343+
344+ // Add new metadata node and put the printf string into it.
345+ // The first element of metadata node is the string index,
346+ // the second element is the string itself.
347+ NamedMDNode* namedMDNode = m_module->getOrInsertNamedMetadata (getPrintfStringsMDNodeName (F));
348+ SmallVector<Metadata*, 2 > args;
349+ Metadata* stringIndexVal = ConstantAsMetadata::get (
350+ ConstantInt::get (m_int32Type, m_stringIndex));
298351
299- // Add new metadata node and put the printf string into it.
300- // The first element of metadata node is the string index,
301- // the second element is the string itself.
302- NamedMDNode* namedMDNode = m_module->getOrInsertNamedMetadata (getPrintfStringsMDNodeName (F));
303- SmallVector<Metadata*, 2 > args;
304- Metadata* stringIndexVal = ConstantAsMetadata::get (
305- ConstantInt::get (m_int32Type, m_stringIndex++));
352+ sv.push_back (m_stringIndex++);
306353
307- std::string escaped_string = getEscapedString (formatStringConst);
308- MDString* final_string = MDString::get (*m_context, escaped_string);
354+ std::string escaped_string = getEscapedString (formatStringConst);
355+ MDString* final_string = MDString::get (*m_context, escaped_string);
309356
310- args.push_back (stringIndexVal);
311- args.push_back (final_string);
357+ args.push_back (stringIndexVal);
358+ args.push_back (final_string);
312359
313- MDNode* itemMDNode = MDNode::get (*m_context, args);
314- namedMDNode->addOperand (itemMDNode);
360+ MDNode* itemMDNode = MDNode::get (*m_context, args);
361+ namedMDNode->addOperand (itemMDNode);
362+ }
315363
316- return m_stringIndex - 1 ;
364+ // Checks if the vector have two elements.
365+ // If it has it adds a new phi/select instruction that is responsible
366+ // for the correct execution of the basic instruction.
367+ // This information is forwarded to the store instruction.
368+ if (curr.size () == 2 )
369+ {
370+ if (GetElementPtrInst* getElemPtrInst = dyn_cast<GetElementPtrInst>(printfArg))
371+ {
372+ if (PHINode* phiNode = dyn_cast<PHINode>(getElemPtrInst->getPointerOperand ()))
373+ {
374+ PHINode* phiNode2 = PHINode::Create (m_int32Type, 2 , " " , phiNode);
375+ phiNode2->addIncoming (ConstantInt::get (m_int32Type, sv.front ()), phiNode->getIncomingBlock (0 ));
376+ phiNode2->addIncoming (ConstantInt::get (m_int32Type, sv.back ()), phiNode->getIncomingBlock (1 ));
377+ return phiNode2;
378+ }
379+ }
380+ else if (SelectInst* selectInst = dyn_cast<SelectInst>(printfArg))
381+ {
382+ SelectInst* selectInst2 = SelectInst::Create (selectInst->getOperand (0 ), ConstantInt::get (m_int32Type, sv.front ()),
383+ ConstantInt::get (m_int32Type, sv.back ()), " " , selectInst);
384+ return selectInst2;
385+ }
386+ else
387+ {
388+ assert (0 && " Instructions in the vector are not supported!" );
389+ }
390+ }
391+ return ConstantInt::get (m_int32Type, m_stringIndex - 1 );
317392}
318393
319394
320395bool OpenCLPrintfResolution::argIsString (Value* printfArg)
321396{
322- GlobalVariable* formatString = getGlobalVariable (printfArg);
323- if (nullptr == formatString)
397+ GlobalVariable* formatString = nullptr ;
398+ SmallVector<Value*, 2 > curr = getGlobalVariable (printfArg);
399+
400+ for (auto curr_i : curr)
324401 {
325- return false ;
402+ auto & curr_e = *curr_i;
403+ formatString = dyn_cast_or_null<GlobalVariable>(&curr_e);
404+ if (nullptr == formatString)
405+ {
406+ return false ;
407+ }
408+ ConstantDataArray* formatStringConst = dyn_cast<ConstantDataArray>(formatString->getInitializer ());
409+ if ((nullptr == formatStringConst) && !formatStringConst->isCString ())
410+ {
411+ return false ;
412+ }
326413 }
327-
328- ConstantDataArray* formatStringConst = dyn_cast<ConstantDataArray>(formatString->getInitializer ());
329- return ((nullptr != formatStringConst) && formatStringConst->isCString ());
414+ return true ;
330415}
331416
332417std::string OpenCLPrintfResolution::getPrintfStringsMDNodeName (Function& F)
@@ -578,8 +663,7 @@ Value* OpenCLPrintfResolution::fixupPrintfArg(CallInst& printfCall, Value* arg,
578663 case USC::SHADER_PRINTF_STRING_LITERAL:
579664 {
580665 Function* F = printfCall.getParent ()->getParent ();
581- uint stringIndex = processPrintfString (arg, *F);
582- return ConstantInt::get (m_int32Type, stringIndex);
666+ return processPrintfString (arg, *F);
583667 }
584668 break ;
585669 case USC::SHADER_PRINTF_POINTER:
0 commit comments