@@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
2626
2727 SmallVector<LLVM::GlobalOp, 3 > workgroupBuffers;
2828 workgroupBuffers.reserve (gpuFuncOp.getNumWorkgroupAttributions ());
29- for (const auto &en : llvm::enumerate (gpuFuncOp.getWorkgroupAttributions ())) {
30- BlockArgument attribution = en.value ();
31-
29+ for (const auto [idx, attribution] :
30+ llvm::enumerate (gpuFuncOp.getWorkgroupAttributions ())) {
3231 auto type = dyn_cast<MemRefType>(attribution.getType ());
3332 assert (type && type.hasStaticShape () && " unexpected type in attribution" );
3433
@@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
3736 auto elementType =
3837 cast<Type>(typeConverter->convertType (type.getElementType ()));
3938 auto arrayType = LLVM::LLVMArrayType::get (elementType, numElements);
40- std::string name = std::string (
41- llvm::formatv (" __wg_{0}_{1}" , gpuFuncOp.getName (), en. index () ));
39+ std::string name =
40+ std::string ( llvm::formatv (" __wg_{0}_{1}" , gpuFuncOp.getName (), idx ));
4241 uint64_t alignment = 0 ;
4342 if (auto alignAttr =
4443 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr (
45- en. index () , LLVM::LLVMDialect::getAlignAttrName ())))
44+ idx , LLVM::LLVMDialect::getAlignAttrName ())))
4645 alignment = alignAttr.getInt ();
4746 auto globalOp = rewriter.create <LLVM::GlobalOp>(
4847 gpuFuncOp.getLoc (), arrayType, /* isConstant=*/ false ,
@@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
105104 rewriter.setInsertionPointToStart (&gpuFuncOp.front ());
106105 unsigned numProperArguments = gpuFuncOp.getNumArguments ();
107106
108- for (const auto &en : llvm::enumerate (workgroupBuffers)) {
109- LLVM::GlobalOp global = en.value ();
107+ for (const auto [idx, global] : llvm::enumerate (workgroupBuffers)) {
110108 auto ptrType = LLVM::LLVMPointerType::get (rewriter.getContext (),
111109 global.getAddrSpace ());
112110 Value address = rewriter.create <LLVM::AddressOfOp>(
@@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
119117 // existing memref infrastructure. This may use more registers than
120118 // otherwise necessary given that memref sizes are fixed, but we can try
121119 // and canonicalize that away later.
122- Value attribution = gpuFuncOp.getWorkgroupAttributions ()[en. index () ];
120+ Value attribution = gpuFuncOp.getWorkgroupAttributions ()[idx ];
123121 auto type = cast<MemRefType>(attribution.getType ());
124122 auto descr = MemRefDescriptor::fromStaticShape (
125123 rewriter, loc, *getTypeConverter (), type, memory);
126- signatureConversion.remapInput (numProperArguments + en. index () , descr);
124+ signatureConversion.remapInput (numProperArguments + idx , descr);
127125 }
128126
129127 // Rewrite private memory attributions to alloca'ed buffers.
130128 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions ();
131129 auto int64Ty = IntegerType::get (rewriter.getContext (), 64 );
132- for (const auto &en : llvm::enumerate (gpuFuncOp. getPrivateAttributions ())) {
133- Value attribution = en. value ();
130+ for (const auto [idx, attribution] :
131+ llvm::enumerate (gpuFuncOp. getPrivateAttributions ())) {
134132 auto type = cast<MemRefType>(attribution.getType ());
135133 assert (type && type.hasStaticShape () && " unexpected type in attribution" );
136134
@@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
145143 uint64_t alignment = 0 ;
146144 if (auto alignAttr =
147145 dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr (
148- en. index () , LLVM::LLVMDialect::getAlignAttrName ())))
146+ idx , LLVM::LLVMDialect::getAlignAttrName ())))
149147 alignment = alignAttr.getInt ();
150148 Value allocated = rewriter.create <LLVM::AllocaOp>(
151149 gpuFuncOp.getLoc (), ptrType, elementType, numElements, alignment);
152150 auto descr = MemRefDescriptor::fromStaticShape (
153151 rewriter, loc, *getTypeConverter (), type, allocated);
154152 signatureConversion.remapInput (
155- numProperArguments + numWorkgroupAttributions + en. index () , descr);
153+ numProperArguments + numWorkgroupAttributions + idx , descr);
156154 }
157155 }
158156
@@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
169167 if (getTypeConverter ()->getOptions ().useBarePtrCallConv ) {
170168 OpBuilder::InsertionGuard guard (rewriter);
171169 rewriter.setInsertionPointToStart (&llvmFuncOp.getBody ().front ());
172- for (const auto &en : llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
173- auto memrefTy = dyn_cast<MemRefType>(en.value ());
170+ for (const auto [idx, argTy] :
171+ llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
172+ auto memrefTy = dyn_cast<MemRefType>(argTy);
174173 if (!memrefTy)
175174 continue ;
176175 assert (memrefTy.hasStaticShape () &&
177176 " Bare pointer convertion used with dynamically-shaped memrefs" );
178177 // Use a placeholder when replacing uses of the memref argument to prevent
179178 // circular replacements.
180- auto remapping = signatureConversion.getInputMapping (en. index () );
179+ auto remapping = signatureConversion.getInputMapping (idx );
181180 assert (remapping && remapping->size == 1 &&
182181 " Type converter should produce 1-to-1 mapping for bare memrefs" );
183182 BlockArgument newArg =
@@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
193192
194193 // Get memref type from function arguments and set the noalias to
195194 // pointer arguments.
196- for (const auto &en : llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
197- auto memrefTy = en.value ().dyn_cast <MemRefType>();
198- NamedAttrList argAttr = argAttrs
199- ? argAttrs[en.index ()].cast <DictionaryAttr>()
200- : NamedAttrList ();
201-
195+ for (const auto [idx, argTy] :
196+ llvm::enumerate (gpuFuncOp.getArgumentTypes ())) {
197+ auto remapping = signatureConversion.getInputMapping (idx);
198+ NamedAttrList argAttr =
199+ argAttrs ? argAttrs[idx].cast <DictionaryAttr>() : NamedAttrList ();
200+ auto copyAttribute = [&](StringRef attrName) {
201+ Attribute attr = argAttr.erase (attrName);
202+ if (!attr)
203+ return ;
204+ for (size_t i = 0 , e = remapping->size ; i < e; ++i)
205+ llvmFuncOp.setArgAttr (remapping->inputNo + i, attrName, attr);
206+ };
202207 auto copyPointerAttribute = [&](StringRef attrName) {
203208 Attribute attr = argAttr.erase (attrName);
204209
205- // This is a proxy for the bare pointer calling convention.
206210 if (!attr)
207211 return ;
208- auto remapping = signatureConversion.getInputMapping (en.index ());
209212 if (remapping->size > 1 &&
210213 attrName == LLVM::LLVMDialect::getNoAliasAttrName ()) {
211214 emitWarning (llvmFuncOp.getLoc (),
@@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
224227 if (argAttr.empty ())
225228 continue ;
226229
227- if (memrefTy) {
230+ copyAttribute (LLVM::LLVMDialect::getReturnedAttrName ());
231+ copyAttribute (LLVM::LLVMDialect::getNoUndefAttrName ());
232+ copyAttribute (LLVM::LLVMDialect::getInRegAttrName ());
233+ bool lowersToPointer = false ;
234+ for (size_t i = 0 , e = remapping->size ; i < e; ++i) {
235+ lowersToPointer |= isa<LLVM::LLVMPointerType>(
236+ llvmFuncOp.getArgument (remapping->inputNo + i).getType ());
237+ }
238+
239+ if (lowersToPointer) {
228240 copyPointerAttribute (LLVM::LLVMDialect::getNoAliasAttrName ());
241+ copyPointerAttribute (LLVM::LLVMDialect::getNoCaptureAttrName ());
242+ copyPointerAttribute (LLVM::LLVMDialect::getNoFreeAttrName ());
243+ copyPointerAttribute (LLVM::LLVMDialect::getAlignAttrName ());
229244 copyPointerAttribute (LLVM::LLVMDialect::getReadonlyAttrName ());
230245 copyPointerAttribute (LLVM::LLVMDialect::getWriteOnlyAttrName ());
246+ copyPointerAttribute (LLVM::LLVMDialect::getReadnoneAttrName ());
231247 copyPointerAttribute (LLVM::LLVMDialect::getNonNullAttrName ());
232248 copyPointerAttribute (LLVM::LLVMDialect::getDereferenceableAttrName ());
233249 copyPointerAttribute (
0 commit comments