@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153 type.isVarArg ());
154154 });
155155
156+ // Add generic source and target materializations to handle cases where
157+ // non-LLVM types persist after an LLVM conversion.
158+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159+ ValueRange inputs, Location loc) {
160+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161+ .getResult (0 );
162+ });
163+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164+ ValueRange inputs, Location loc) {
165+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166+ .getResult (0 );
167+ });
168+
156169 // Helper function that checks if the given value range is a bare pointer.
157170 auto isBarePointer = [](ValueRange values) {
158171 return values.size () == 1 &&
159172 isa<LLVM::LLVMPointerType>(values.front ().getType ());
160173 };
161174
162- // Argument materializations convert from the new block argument types
163- // (multiple SSA values that make up a memref descriptor) back to the
164- // original block argument type. The dialect conversion framework will then
165- // insert a target materialization from the original block argument type to
166- // a legal type.
167- addArgumentMaterialization ([&](OpBuilder &builder,
168- UnrankedMemRefType resultType,
169- ValueRange inputs, Location loc) {
175+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176+ // must be passed explicitly.
177+ auto packUnrankedMemRefDesc =
178+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179+ Location loc, LLVMTypeConverter &converter) -> Value {
170180 // Note: Bare pointers are not supported for unranked memrefs because a
171181 // memref descriptor cannot be built just from a bare pointer.
172- if (TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
182+ if (TypeRange (inputs) != converter. getUnrankedMemRefDescriptorFields ())
173183 return Value ();
174- Value desc =
175- UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
184+ return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
185+ inputs);
186+ };
187+
188+ // MemRef descriptor elements -> UnrankedMemRefType
189+ auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190+ UnrankedMemRefType resultType,
191+ ValueRange inputs, Location loc) {
176192 // An argument materialization must return a value of type
177193 // `resultType`, so insert a cast from the memref descriptor type
178194 // (!llvm.struct) to the original memref type.
179- return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
180- .getResult (0 );
181- });
182- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
183- ValueRange inputs, Location loc) {
184- Value desc;
185- if (isBarePointer (inputs)) {
186- desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
187- inputs[0 ]);
188- } else if (TypeRange (inputs) ==
189- getMemRefDescriptorFields (resultType,
190- /* unpackAggregates=*/ true )) {
191- desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
192- } else {
193- // The inputs are neither a bare pointer nor an unpacked memref
194- // descriptor. This materialization function cannot be used.
195+ Value packed =
196+ packUnrankedMemRefDesc (builder, resultType, inputs, loc, *this );
197+ if (!packed)
195198 return Value ();
196- }
199+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
200+ .getResult (0 );
201+ };
202+
203+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204+ // must be passed explicitly.
205+ auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206+ ValueRange inputs, Location loc,
207+ LLVMTypeConverter &converter) -> Value {
208+ assert (resultType && " expected non-null result type" );
209+ if (isBarePointer (inputs))
210+ return MemRefDescriptor::fromStaticShape (builder, loc, converter,
211+ resultType, inputs[0 ]);
212+ if (TypeRange (inputs) ==
213+ converter.getMemRefDescriptorFields (resultType,
214+ /* unpackAggregates=*/ true ))
215+ return MemRefDescriptor::pack (builder, loc, converter, resultType,
216+ inputs);
217+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
218+ // This materialization function cannot be used.
219+ return Value ();
220+ };
221+
222+ // MemRef descriptor elements -> MemRefType
223+ auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224+ MemRefType resultType,
225+ ValueRange inputs, Location loc) {
197226 // An argument materialization must return a value of type `resultType`,
198227 // so insert a cast from the memref descriptor type (!llvm.struct) to the
199228 // original memref type.
200- return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
201- .getResult (0 );
202- });
203- // Add generic source and target materializations to handle cases where
204- // non-LLVM types persist after an LLVM conversion.
205- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
206- ValueRange inputs, Location loc) {
207- if (inputs.size () != 1 )
229+ Value packed =
230+ packRankedMemRefDesc (builder, resultType, inputs, loc, *this );
231+ if (!packed)
208232 return Value ();
209-
210- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
233+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
211234 .getResult (0 );
212- });
235+ };
236+
237+ // Argument materializations convert from the new block argument types
238+ // (multiple SSA values that make up a memref descriptor) back to the
239+ // original block argument type.
240+ addArgumentMaterialization (unrakedMemRefMaterialization);
241+ addArgumentMaterialization (rankedMemRefMaterialization);
242+ addSourceMaterialization (unrakedMemRefMaterialization);
243+ addSourceMaterialization (rankedMemRefMaterialization);
244+
245+ // Bare pointer -> Packed MemRef descriptor
213246 addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
214- ValueRange inputs, Location loc) {
215- if (inputs.size () != 1 )
247+ ValueRange inputs, Location loc,
248+ Type originalType) -> Value {
249+ // The original MemRef type is required to build a MemRef descriptor
250+ // because the sizes/strides of the MemRef cannot be inferred from just the
251+ // bare pointer.
252+ if (!originalType)
216253 return Value ();
217-
218- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
219- .getResult (0 );
254+ if (resultType != convertType (originalType))
255+ return Value ();
256+ if (auto memrefType = dyn_cast<MemRefType>(originalType))
257+ return packRankedMemRefDesc (builder, memrefType, inputs, loc, *this );
258+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
259+ return packUnrankedMemRefDesc (builder, unrankedMemrefType, inputs, loc,
260+ *this );
261+ return Value ();
220262 });
221263
222264 // Integer memory spaces map to themselves.
0 commit comments