@@ -216,28 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
216216 return memRefDescriptor;
217217}
218218
219- LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors (
220- OpBuilder &builder, Location loc, TypeRange origTypes,
221- SmallVectorImpl<Value> &operands, bool toDynamic) const {
222- assert (origTypes.size () == operands.size () &&
223- " expected as may original types as operands" );
224-
225- // Find operands of unranked memref type and store them.
226- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
227- SmallVector<unsigned > unrankedAddressSpaces;
228- for (unsigned i = 0 , e = operands.size (); i < e; ++i) {
229- if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
230- unrankedMemrefs.emplace_back (operands[i]);
231- FailureOr<unsigned > addressSpace =
232- getTypeConverter ()->getMemRefAddressSpace (memRefType);
233- if (failed (addressSpace))
234- return failure ();
235- unrankedAddressSpaces.emplace_back (*addressSpace);
236- }
237- }
238-
239- if (unrankedMemrefs.empty ())
240- return success ();
219+ Value ConvertToLLVMPattern::copyUnrankedDescriptor (
220+ OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
221+ Value operand, bool toDynamic) const {
222+ // Convert memory space.
223+ FailureOr<unsigned > addressSpace =
224+ getTypeConverter ()->getMemRefAddressSpace (memRefType);
225+ if (failed (addressSpace))
226+ return {};
241227
242228 // Get frequently used types.
243229 Type indexType = getTypeConverter ()->getIndexType ();
@@ -248,54 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
248234 if (toDynamic) {
249235 mallocFunc = LLVM::lookupOrCreateMallocFn (builder, module , indexType);
250236 if (failed (mallocFunc))
251- return failure () ;
237+ return {} ;
252238 }
253239 if (!toDynamic) {
254240 freeFunc = LLVM::lookupOrCreateFreeFn (builder, module );
255241 if (failed (freeFunc))
256- return failure () ;
242+ return {} ;
257243 }
258244
259- unsigned unrankedMemrefPos = 0 ;
260- for (unsigned i = 0 , e = operands.size (); i < e; ++i) {
261- Type type = origTypes[i];
262- if (!isa<UnrankedMemRefType>(type))
263- continue ;
264- UnrankedMemRefDescriptor desc (operands[i]);
265- Value allocationSize = UnrankedMemRefDescriptor::computeSize (
266- builder, loc, *getTypeConverter (), desc,
267- unrankedAddressSpaces[unrankedMemrefPos++]);
268-
269- // Allocate memory, copy, and free the source if necessary.
270- Value memory =
271- toDynamic ? LLVM::CallOp::create (builder, loc, mallocFunc.value (),
272- allocationSize)
273- .getResult ()
274- : LLVM::AllocaOp::create (builder, loc, getPtrType (),
275- IntegerType::get (getContext (), 8 ),
276- allocationSize,
277- /* alignment=*/ 0 );
278- Value source = desc.memRefDescPtr (builder, loc);
279- LLVM::MemcpyOp::create (builder, loc, memory, source, allocationSize, false );
280- if (!toDynamic)
281- LLVM::CallOp::create (builder, loc, freeFunc.value (), source);
282-
283- // Create a new descriptor. The same descriptor can be returned multiple
284- // times, attempting to modify its pointer can lead to memory leaks
285- // (allocated twice and overwritten) or double frees (the caller does not
286- // know if the descriptor points to the same memory).
287- Type descriptorType = getTypeConverter ()->convertType (type);
288- if (!descriptorType)
289- return failure ();
290- auto updatedDesc =
291- UnrankedMemRefDescriptor::poison (builder, loc, descriptorType);
292- Value rank = desc.rank (builder, loc);
293- updatedDesc.setRank (builder, loc, rank);
294- updatedDesc.setMemRefDescPtr (builder, loc, memory);
245+ UnrankedMemRefDescriptor desc (operand);
246+ Value allocationSize = UnrankedMemRefDescriptor::computeSize (
247+ builder, loc, *getTypeConverter (), desc, *addressSpace);
248+
249+ // Allocate memory, copy, and free the source if necessary.
250+ Value memory = toDynamic
251+ ? LLVM::CallOp::create (builder, loc, mallocFunc.value (),
252+ allocationSize)
253+ .getResult ()
254+ : LLVM::AllocaOp::create (builder, loc, getPtrType (),
255+ IntegerType::get (getContext (), 8 ),
256+ allocationSize,
257+ /* alignment=*/ 0 );
258+ Value source = desc.memRefDescPtr (builder, loc);
259+ LLVM::MemcpyOp::create (builder, loc, memory, source, allocationSize, false );
260+ if (!toDynamic)
261+ LLVM::CallOp::create (builder, loc, freeFunc.value (), source);
262+
263+ // Create a new descriptor. The same descriptor can be returned multiple
264+ // times, attempting to modify its pointer can lead to memory leaks
265+ // (allocated twice and overwritten) or double frees (the caller does not
266+ // know if the descriptor points to the same memory).
267+ Type descriptorType = getTypeConverter ()->convertType (memRefType);
268+ if (!descriptorType)
269+ return {};
270+ auto updatedDesc =
271+ UnrankedMemRefDescriptor::poison (builder, loc, descriptorType);
272+ Value rank = desc.rank (builder, loc);
273+ updatedDesc.setRank (builder, loc, rank);
274+ updatedDesc.setMemRefDescPtr (builder, loc, memory);
275+ return updatedDesc;
276+ }
295277
296- operands[i] = updatedDesc;
278+ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors (
279+ OpBuilder &builder, Location loc, TypeRange origTypes,
280+ SmallVectorImpl<Value> &operands, bool toDynamic) const {
281+ assert (origTypes.size () == operands.size () &&
282+ " expected as may original types as operands" );
283+ for (unsigned i = 0 , e = operands.size (); i < e; ++i) {
284+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
285+ Value updatedDesc = copyUnrankedDescriptor (builder, loc, memRefType,
286+ operands[i], toDynamic);
287+ if (!updatedDesc)
288+ return failure ();
289+ operands[i] = updatedDesc;
290+ }
297291 }
298-
299292 return success ();
300293}
301294
0 commit comments