@@ -44,6 +44,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
4444 const DataLayoutAnalysis *analysis)
4545 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
4646
47+ // / Helper function that checks if the given value range is a bare pointer.
48+ static bool isBarePointer (ValueRange values) {
49+ return values.size () == 1 &&
50+ isa<LLVM::LLVMPointerType>(values.front ().getType ());
51+ };
52+
53+ // / Pack SSA values into an unranked memref descriptor struct.
54+ static Value packUnrankedMemRefDesc (OpBuilder &builder,
55+ UnrankedMemRefType resultType,
56+ ValueRange inputs, Location loc,
57+ const LLVMTypeConverter &converter) {
58+ // Note: Bare pointers are not supported for unranked memrefs because a
59+ // memref descriptor cannot be built just from a bare pointer.
60+ if (TypeRange (inputs) != converter.getUnrankedMemRefDescriptorFields ())
61+ return Value ();
62+ return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
63+ inputs);
64+ }
65+
66+ // / Pack SSA values into a ranked memref descriptor struct.
67+ static Value packRankedMemRefDesc (OpBuilder &builder, MemRefType resultType,
68+ ValueRange inputs, Location loc,
69+ const LLVMTypeConverter &converter) {
70+ assert (resultType && " expected non-null result type" );
71+ if (isBarePointer (inputs))
72+ return MemRefDescriptor::fromStaticShape (builder, loc, converter,
73+ resultType, inputs[0 ]);
74+ if (TypeRange (inputs) ==
75+ converter.getMemRefDescriptorFields (resultType,
76+ /* unpackAggregates=*/ true ))
77+ return MemRefDescriptor::pack (builder, loc, converter, resultType, inputs);
78+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79+ // This materialization function cannot be used.
80+ return Value ();
81+ }
82+
83+ // / MemRef descriptor elements -> UnrankedMemRefType
84+ static Value unrankedMemRefMaterialization (OpBuilder &builder,
85+ UnrankedMemRefType resultType,
86+ ValueRange inputs, Location loc,
87+ const LLVMTypeConverter &converter) {
88+ // An argument materialization must return a value of type
89+ // `resultType`, so insert a cast from the memref descriptor type
90+ // (!llvm.struct) to the original memref type.
91+ Value packed =
92+ packUnrankedMemRefDesc (builder, resultType, inputs, loc, converter);
93+ if (!packed)
94+ return Value ();
95+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
96+ .getResult (0 );
97+ };
98+
99+ // / MemRef descriptor elements -> MemRefType
100+ static Value rankedMemRefMaterialization (OpBuilder &builder,
101+ MemRefType resultType,
102+ ValueRange inputs, Location loc,
103+ const LLVMTypeConverter &converter) {
104+ // An argument materialization must return a value of type `resultType`,
105+ // so insert a cast from the memref descriptor type (!llvm.struct) to the
106+ // original memref type.
107+ Value packed =
108+ packRankedMemRefDesc (builder, resultType, inputs, loc, converter);
109+ if (!packed)
110+ return Value ();
111+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
112+ .getResult (0 );
113+ }
114+
47115// / Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48116LLVMTypeConverter::LLVMTypeConverter (MLIRContext *ctx,
49117 const LowerToLLVMOptions &options,
@@ -166,81 +234,29 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
166234 .getResult (0 );
167235 });
168236
169- // Helper function that checks if the given value range is a bare pointer.
170- auto isBarePointer = [](ValueRange values) {
171- return values.size () == 1 &&
172- isa<LLVM::LLVMPointerType>(values.front ().getType ());
173- };
174-
175- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176- // must be passed explicitly.
177- auto packUnrankedMemRefDesc =
178- [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179- Location loc, LLVMTypeConverter &converter) -> Value {
180- // Note: Bare pointers are not supported for unranked memrefs because a
181- // memref descriptor cannot be built just from a bare pointer.
182- if (TypeRange (inputs) != converter.getUnrankedMemRefDescriptorFields ())
183- return Value ();
184- return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
185- inputs);
186- };
187-
188- // MemRef descriptor elements -> UnrankedMemRefType
189- auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190- UnrankedMemRefType resultType,
191- ValueRange inputs, Location loc) {
192- // An argument materialization must return a value of type
193- // `resultType`, so insert a cast from the memref descriptor type
194- // (!llvm.struct) to the original memref type.
195- Value packed =
196- packUnrankedMemRefDesc (builder, resultType, inputs, loc, *this );
197- if (!packed)
198- return Value ();
199- return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
200- .getResult (0 );
201- };
202-
203- // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204- // must be passed explicitly.
205- auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206- ValueRange inputs, Location loc,
207- LLVMTypeConverter &converter) -> Value {
208- assert (resultType && " expected non-null result type" );
209- if (isBarePointer (inputs))
210- return MemRefDescriptor::fromStaticShape (builder, loc, converter,
211- resultType, inputs[0 ]);
212- if (TypeRange (inputs) ==
213- converter.getMemRefDescriptorFields (resultType,
214- /* unpackAggregates=*/ true ))
215- return MemRefDescriptor::pack (builder, loc, converter, resultType,
216- inputs);
217- // The inputs are neither a bare pointer nor an unpacked memref descriptor.
218- // This materialization function cannot be used.
219- return Value ();
220- };
221-
222- // MemRef descriptor elements -> MemRefType
223- auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224- MemRefType resultType,
225- ValueRange inputs, Location loc) {
226- // An argument materialization must return a value of type `resultType`,
227- // so insert a cast from the memref descriptor type (!llvm.struct) to the
228- // original memref type.
229- Value packed =
230- packRankedMemRefDesc (builder, resultType, inputs, loc, *this );
231- if (!packed)
232- return Value ();
233- return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
234- .getResult (0 );
235- };
236-
237237 // Argument materializations convert from the new block argument types
238238 // (multiple SSA values that make up a memref descriptor) back to the
239239 // original block argument type.
240- addArgumentMaterialization (unrakedMemRefMaterialization);
241- addArgumentMaterialization (rankedMemRefMaterialization);
242- addSourceMaterialization (unrakedMemRefMaterialization);
243- addSourceMaterialization (rankedMemRefMaterialization);
240+ addArgumentMaterialization ([&](OpBuilder &builder,
241+ UnrankedMemRefType resultType,
242+ ValueRange inputs, Location loc) {
243+ return unrankedMemRefMaterialization (builder, resultType, inputs, loc,
244+ *this );
245+ });
246+ addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
247+ ValueRange inputs, Location loc) {
248+ return rankedMemRefMaterialization (builder, resultType, inputs, loc, *this );
249+ });
250+ addSourceMaterialization ([&](OpBuilder &builder,
251+ UnrankedMemRefType resultType, ValueRange inputs,
252+ Location loc) {
253+ return unrankedMemRefMaterialization (builder, resultType, inputs, loc,
254+ *this );
255+ });
256+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
257+ ValueRange inputs, Location loc) {
258+ return rankedMemRefMaterialization (builder, resultType, inputs, loc, *this );
259+ });
244260
245261 // Bare pointer -> Packed MemRef descriptor
246262 addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
0 commit comments