@@ -153,68 +153,106 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153 type.isVarArg ());
154154 });
155155
156- // Argument materializations convert from the new block argument types
157- // (multiple SSA values that make up a memref descriptor) back to the
158- // original block argument type. The dialect conversion framework will then
159- // insert a target materialization from the original block argument type to
160- // a legal type.
161- addArgumentMaterialization ([&](OpBuilder &builder,
162- UnrankedMemRefType resultType,
163- ValueRange inputs, Location loc) {
156+ // Add generic source and target materializations to handle cases where
157+ // non-LLVM types persist after an LLVM conversion.
158+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159+ ValueRange inputs, Location loc) {
160+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161+ .getResult (0 );
162+ });
163+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164+ ValueRange inputs, Location loc) {
165+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166+ .getResult (0 );
167+ });
168+
169+ // Source materializations convert the MemrRef descriptor elements
170+ // (multiple SSA values that make up a MemrRef descriptor) back to the
171+ // original MemRef type.
172+ addSourceMaterialization ([&](OpBuilder &builder,
173+ UnrankedMemRefType resultType, ValueRange inputs,
174+ Location loc) {
164175 if (inputs.size () == 1 ) {
165176 // Bare pointers are not supported for unranked memrefs because a
166177 // memref descriptor cannot be built just from a bare pointer.
167178 return Value ();
168179 }
169180 Value desc =
170181 UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171- // An argument materialization must return a value of type
182+ // A source materialization must return a value of type
172183 // `resultType`, so insert a cast from the memref descriptor type
173184 // (!llvm.struct) to the original memref type.
174185 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
175186 .getResult (0 );
176187 });
177- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178- ValueRange inputs, Location loc) {
188+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
189+ ValueRange inputs, Location loc) {
190+ if (inputs.size () == 1 &&
191+ isa<LLVM::LLVMStructType>(inputs.front ().getType ()))
192+ return Value ();
193+
179194 Value desc;
180- if (inputs.size () == 1 ) {
195+ if (inputs.size () == 1 &&
196+ isa<LLVM::LLVMPointerType>(inputs.front ().getType ())) {
181197 // This is a bare pointer. We allow bare pointers only for function entry
182198 // blocks.
183199 BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
184200 if (!barePtr)
185201 return Value ();
186- Block *block = barePtr.getOwner ();
187- if (!block->isEntryBlock () ||
188- !isa<FunctionOpInterface>(block->getParentOp ()))
189- return Value ();
202+ // Block *block = barePtr.getOwner();
203+ // if (!block->isEntryBlock() ||
204+ // !isa<FunctionOpInterface>(block->getParentOp()))
205+ // return Value();
190206 desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
191207 inputs[0 ]);
192208 } else {
193209 desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
194210 }
195- // An argument materialization must return a value of type `resultType`,
211+ // A source materialization must return a value of type `resultType`,
196212 // so insert a cast from the memref descriptor type (!llvm.struct) to the
197213 // original memref type.
198214 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
199215 .getResult (0 );
200216 });
201- // Add generic source and target materializations to handle cases where
202- // non-LLVM types persist after an LLVM conversion.
203- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
204- ValueRange inputs, Location loc) {
205- if (inputs.size () != 1 )
206- return Value ();
217+ addTargetMaterialization ([&](OpBuilder &builder, LLVM::LLVMStructType resultType,
218+ ValueRange inputs, Location loc,
219+ Type originalType) -> Value {
220+ if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
221+ if (inputs.size () == 1 ) {
222+ Value input = inputs.front ();
223+ // if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
224+ // if (castOp.getInputs().size() == 1 &&
225+ // isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
226+ // input = castOp.getInputs()[0];
227+ // }
228+ // }
229+ if (!isa<LLVM::LLVMPointerType>(input.getType ()))
230+ return Value ();
231+ BlockArgument barePtr = dyn_cast<BlockArgument>(input);
232+ if (!barePtr)
233+ return Value ();
234+ // Block *block = barePtr.getOwner();
235+ // if (!block->isEntryBlock() ||
236+ // !isa<FunctionOpInterface>(block->getParentOp()))
237+ // return Value();
238+ // Bare ptr
239+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
240+ memrefType, input);
241+ }
242+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
243+ }
207244
208- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
209- .getResult (0 );
210- });
211- addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
212- ValueRange inputs, Location loc) {
213- if (inputs.size () != 1 )
214- return Value ();
245+ if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
246+ if (inputs.size () == 1 ) {
247+ // Bare pointers are not supported for unranked memrefs because a
248+ // memref descriptor cannot be built just from a bare pointer.
249+ return Value ();
250+ }
251+ return UnrankedMemRefDescriptor::pack (builder, loc, *this , memrefType,
252+ inputs);
253+ }
215254
216- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
217- .getResult (0 );
255+ return Value ();
218256 });
219257
220258 // Integer memory spaces map to themselves.
0 commit comments