@@ -153,20 +153,31 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153 type.isVarArg ());
154154 });
155155
156+ // Add generic source and target materializations to handle cases where
157+ // non-LLVM types persist after an LLVM conversion.
158+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159+ ValueRange inputs, Location loc) {
160+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161+ .getResult (0 );
162+ });
163+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164+ ValueRange inputs, Location loc) {
165+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166+ .getResult (0 );
167+ });
168+
156169 // Helper function that checks if the given value range is a bare pointer.
157170 auto isBarePointer = [](ValueRange values) {
158171 return values.size () == 1 &&
159172 isa<LLVM::LLVMPointerType>(values.front ().getType ());
160173 };
161174
162- // Argument materializations convert from the new block argument types
163- // (multiple SSA values that make up a memref descriptor) back to the
164- // original block argument type. The dialect conversion framework will then
165- // insert a target materialization from the original block argument type to
166- // a legal type.
167- addArgumentMaterialization ([&](OpBuilder &builder,
168- UnrankedMemRefType resultType,
169- ValueRange inputs, Location loc) {
175+ // Source materializations convert the MemrRef descriptor elements
176+ // (multiple SSA values that make up a MemrRef descriptor) back to the
177+ // original MemRef type.
178+ addSourceMaterialization ([&](OpBuilder &builder,
179+ UnrankedMemRefType resultType, ValueRange inputs,
180+ Location loc) {
170181 // Note: Bare pointers are not supported for unranked memrefs because a
171182 // memref descriptor cannot be built just from a bare pointer.
172183 if (TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
@@ -179,8 +190,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
179190 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
180191 .getResult (0 );
181192 });
182- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
183- ValueRange inputs, Location loc) {
193+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
194+ ValueRange inputs, Location loc) {
184195 Value desc;
185196 if (isBarePointer (inputs)) {
186197 desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
@@ -200,23 +211,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
200211 return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
201212 .getResult (0 );
202213 });
203- // Add generic source and target materializations to handle cases where
204- // non-LLVM types persist after an LLVM conversion.
205- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
206- ValueRange inputs, Location loc) {
207- if (inputs.size () != 1 )
208- return Value ();
214+ addTargetMaterialization ([&](OpBuilder &builder,
215+ LLVM::LLVMStructType resultType,
216+ ValueRange inputs, Location loc,
217+ Type originalType) -> Value {
218+ if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
219+ if (isBarePointer (inputs)) {
220+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
221+ memrefType, inputs[0 ]);
222+ } else if (TypeRange (inputs) ==
223+ getMemRefDescriptorFields (memrefType,
224+ /* unpackAggregates=*/ true )) {
225+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
226+ }
227+ }
209228
210- return builder. create <UnrealizedConversionCastOp>(loc, resultType, inputs)
211- . getResult ( 0 );
212- });
213- addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
214- ValueRange inputs, Location loc) {
215- if (inputs. size () != 1 )
216- return Value ();
229+ if ( auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
230+ // Note: Bare pointers are not supported for unranked memrefs because a
231+ // memref descriptor cannot be built just from a bare pointer.
232+ if ( TypeRange (inputs) == getUnrankedMemRefDescriptorFields ())
233+ return UnrankedMemRefDescriptor::pack (builder, loc, * this , memrefType,
234+ inputs);
235+ }
217236
218- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
219- .getResult (0 );
237+ return Value ();
220238 });
221239
222240 // Integer memory spaces map to themselves.
0 commit comments