@@ -153,6 +153,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153 type.isVarArg ());
154154 });
155155
156+ // Helper function that checks if the given value range is a bare pointer.
157+ auto isBarePointer = [](ValueRange values) {
158+ return values.size () == 1 &&
159+ isa<LLVM::LLVMPointerType>(values.front ().getType ());
160+ };
161+
156162 // Argument materializations convert from the new block argument types
157163 // (multiple SSA values that make up a memref descriptor) back to the
158164 // original block argument type. The dialect conversion framework will then
@@ -161,11 +167,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
161167 addArgumentMaterialization ([&](OpBuilder &builder,
162168 UnrankedMemRefType resultType,
163169 ValueRange inputs, Location loc) {
164- if (inputs. size () == 1 ) {
165- // Bare pointers are not supported for unranked memrefs because a
166- // memref descriptor cannot be built just from a bare pointer.
170+ // Note: Bare pointers are not supported for unranked memrefs because a
171+ // memref descriptor cannot be built just from a bare pointer.
172+ if ( TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
167173 return Value ();
168- }
169174 Value desc =
170175 UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171176 // An argument materialization must return a value of type
@@ -177,20 +182,17 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
177182 addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178183 ValueRange inputs, Location loc) {
179184 Value desc;
180- if (inputs.size () == 1 ) {
181- // This is a bare pointer. We allow bare pointers only for function entry
182- // blocks.
183- BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
184- if (!barePtr)
185- return Value ();
186- Block *block = barePtr.getOwner ();
187- if (!block->isEntryBlock () ||
188- !isa<FunctionOpInterface>(block->getParentOp ()))
189- return Value ();
185+ if (isBarePointer (inputs)) {
190186 desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
191187 inputs[0 ]);
192- } else {
188+ } else if (TypeRange (inputs) ==
189+ getMemRefDescriptorFields (resultType,
190+ /* unpackAggregates=*/ true )) {
193191 desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
192+ } else {
193+ // The inputs are neither a bare pointer nor an unpacked memref
194+ // descriptor. This materialization function cannot be used.
195+ return Value ();
194196 }
195197 // An argument materialization must return a value of type `resultType`,
196198 // so insert a cast from the memref descriptor type (!llvm.struct) to the
0 commit comments