@@ -153,31 +153,42 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153 type.isVarArg ());
154154 });
155155
156- // Argument materializations convert from the new block argument types
157- // (multiple SSA values that make up a memref descriptor) back to the
158- // original block argument type. The dialect conversion framework will then
159- // insert a target materialization from the original block argument type to
160- // a legal type.
161- addArgumentMaterialization ([&](OpBuilder &builder,
162- UnrankedMemRefType resultType,
163- ValueRange inputs, Location loc) {
156+ // Add generic source and target materializations to handle cases where
157+ // non-LLVM types persist after an LLVM conversion.
158+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159+ ValueRange inputs, Location loc) {
160+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161+ .getResult (0 );
162+ });
163+
164+ // Source materializations convert the MemrRef descriptor elements
165+ // (multiple SSA values that make up a MemrRef descriptor) back to the
166+ // original MemRef type.
167+ addSourceMaterialization ([&](OpBuilder &builder,
168+ UnrankedMemRefType resultType, ValueRange inputs,
169+ Location loc) {
164170 if (inputs.size () == 1 ) {
165171 // Bare pointers are not supported for unranked memrefs because a
166172 // memref descriptor cannot be built just from a bare pointer.
167173 return Value ();
168174 }
169175 Value desc =
170176 UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171- // An argument materialization must return a value of type
177+ // A source materialization must return a value of type
172178 // `resultType`, so insert a cast from the memref descriptor type
173179 // (!llvm.struct) to the original memref type.
174180 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
175181 .getResult (0 );
176182 });
177- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178- ValueRange inputs, Location loc) {
183+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
184+ ValueRange inputs, Location loc) {
185+ if (inputs.size () == 1 &&
186+ isa<LLVM::LLVMStructType>(inputs.front ().getType ()))
187+ return Value ();
188+
179189 Value desc;
180- if (inputs.size () == 1 ) {
190+ if (inputs.size () == 1 &&
191+ isa<LLVM::LLVMPointerType>(inputs.front ().getType ())) {
181192 // This is a bare pointer. We allow bare pointers only for function entry
182193 // blocks.
183194 BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
@@ -192,15 +203,13 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
192203 } else {
193204 desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
194205 }
195- // An argument materialization must return a value of type `resultType`,
206+ // A source materialization must return a value of type `resultType`,
196207 // so insert a cast from the memref descriptor type (!llvm.struct) to the
197208 // original memref type.
198209 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
199210 .getResult (0 );
200211 });
201- // Add generic source and target materializations to handle cases where
202- // non-LLVM types persist after an LLVM conversion.
203- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
212+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
204213 ValueRange inputs, Location loc) {
205214 if (inputs.size () != 1 )
206215 return Value ();
@@ -209,12 +218,50 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
209218 .getResult (0 );
210219 });
211220 addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
212- ValueRange inputs, Location loc) {
213- if (inputs.size () != 1 )
221+ ValueRange inputs, Location loc,
222+ Type originalType) -> Value {
223+ llvm::errs () << " TARGET MAT: -> " << resultType << " \n " ;
224+ if (!originalType) {
225+ llvm::errs () << " -- no orig\n " ;
214226 return Value ();
227+ }
228+ if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
229+ assert (isa<LLVM::LLVMStructType>(resultType) && " expected struct type" );
230+ if (inputs.size () == 1 ) {
231+ Value input = inputs.front ();
232+ if (auto castOp = input.getDefiningOp <UnrealizedConversionCastOp>()) {
233+ if (castOp.getInputs ().size () == 1 &&
234+ isa<LLVM::LLVMPointerType>(castOp.getInputs ()[0 ].getType ())) {
235+ input = castOp.getInputs ()[0 ];
236+ }
237+ }
238+ if (!isa<LLVM::LLVMPointerType>(input.getType ()))
239+ return Value ();
240+ BlockArgument barePtr = dyn_cast<BlockArgument>(input);
241+ if (!barePtr)
242+ return Value ();
243+ Block *block = barePtr.getOwner ();
244+ if (!block->isEntryBlock () ||
245+ !isa<FunctionOpInterface>(block->getParentOp ()))
246+ return Value ();
247+ // Bare ptr
248+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
249+ memrefType, input);
250+ }
251+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
252+ }
253+ if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
254+ assert (isa<LLVM::LLVMStructType>(resultType) && " expected struct type" );
255+ if (inputs.size () == 1 ) {
256+ // Bare pointers are not supported for unranked memrefs because a
257+ // memref descriptor cannot be built just from a bare pointer.
258+ return Value ();
259+ }
260+ return UnrankedMemRefDescriptor::pack (builder, loc, *this , memrefType,
261+ inputs);
262+ }
215263
216- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
217- .getResult (0 );
264+ return Value ();
218265 });
219266
220267 // Integer memory spaces map to themselves.
0 commit comments