@@ -125,10 +125,171 @@ struct PrivateClauseOpConversion
125
125
return mlir::success ();
126
126
}
127
127
};
128
+
129
+ static mlir::LLVM::LLVMFuncOp getOmpTargetAlloc (mlir::Operation *op) {
130
+ auto module = op->getParentOfType <mlir::ModuleOp>();
131
+ if (mlir::LLVM::LLVMFuncOp mallocFunc =
132
+ module .lookupSymbol <mlir::LLVM::LLVMFuncOp>(" omp_target_alloc" ))
133
+ return mallocFunc;
134
+ mlir::OpBuilder moduleBuilder (module .getBodyRegion ());
135
+ auto i64Ty = mlir::IntegerType::get (module ->getContext (), 64 );
136
+ auto i32Ty = mlir::IntegerType::get (module ->getContext (), 32 );
137
+ return moduleBuilder.create <mlir::LLVM::LLVMFuncOp>(
138
+ moduleBuilder.getUnknownLoc (), " omp_target_alloc" ,
139
+ mlir::LLVM::LLVMFunctionType::get (
140
+ mlir::LLVM::LLVMPointerType::get (module ->getContext ()),
141
+ {i64Ty, i32Ty},
142
+ /* isVarArg=*/ false ));
143
+ }
144
+
145
+ static mlir::Type
146
+ convertObjectType (const fir::LLVMTypeConverter &converter, mlir::Type firType) {
147
+ if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(firType))
148
+ return converter.convertBoxTypeAsStruct (boxTy);
149
+ return converter.convertType (firType);
150
+ }
151
+
152
+ static llvm::SmallVector<mlir::NamedAttribute>
153
+ addLLVMOpBundleAttrs (mlir::ConversionPatternRewriter &rewriter,
154
+ llvm::ArrayRef<mlir::NamedAttribute> attrs,
155
+ int32_t numCallOperands) {
156
+ llvm::SmallVector<mlir::NamedAttribute> newAttrs;
157
+ newAttrs.reserve (attrs.size () + 2 );
158
+
159
+ for (mlir::NamedAttribute attr : attrs) {
160
+ if (attr.getName () != " operandSegmentSizes" )
161
+ newAttrs.push_back (attr);
162
+ }
163
+
164
+ newAttrs.push_back (rewriter.getNamedAttr (
165
+ " operandSegmentSizes" ,
166
+ rewriter.getDenseI32ArrayAttr ({numCallOperands, 0 })));
167
+ newAttrs.push_back (rewriter.getNamedAttr (" op_bundle_sizes" ,
168
+ rewriter.getDenseI32ArrayAttr ({})));
169
+ return newAttrs;
170
+ }
171
+
172
+ static mlir::LLVM::ConstantOp
173
+ genConstantIndex (mlir::Location loc, mlir::Type ity,
174
+ mlir::ConversionPatternRewriter &rewriter,
175
+ std::int64_t offset) {
176
+ auto cattr = rewriter.getI64IntegerAttr (offset);
177
+ return rewriter.create <mlir::LLVM::ConstantOp>(loc, ity, cattr);
178
+ }
179
+
180
+ static mlir::Value
181
+ computeElementDistance (mlir::Location loc, mlir::Type llvmObjectType,
182
+ mlir::Type idxTy,
183
+ mlir::ConversionPatternRewriter &rewriter,
184
+ const mlir::DataLayout &dataLayout) {
185
+ llvm::TypeSize size = dataLayout.getTypeSize (llvmObjectType);
186
+ unsigned short alignment = dataLayout.getTypeABIAlignment (llvmObjectType);
187
+ std::int64_t distance = llvm::alignTo (size, alignment);
188
+ return genConstantIndex (loc, idxTy, rewriter, distance);
189
+ }
190
+
191
+ static mlir::Value genTypeSizeInBytes (mlir::Location loc, mlir::Type idxTy,
192
+ mlir::ConversionPatternRewriter &rewriter,
193
+ mlir::Type llTy, const mlir::DataLayout &dataLayout) {
194
+ return computeElementDistance (loc, llTy, idxTy, rewriter, dataLayout);
195
+ }
196
+
197
+ template <typename OP>
198
+ static mlir::Value
199
+ genAllocationScaleSize (OP op, mlir::Type ity,
200
+ mlir::ConversionPatternRewriter &rewriter) {
201
+ mlir::Location loc = op.getLoc ();
202
+ mlir::Type dataTy = op.getInType ();
203
+ auto seqTy = mlir::dyn_cast<fir::SequenceType>(dataTy);
204
+ fir::SequenceType::Extent constSize = 1 ;
205
+ if (seqTy) {
206
+ int constRows = seqTy.getConstantRows ();
207
+ const fir::SequenceType::ShapeRef &shape = seqTy.getShape ();
208
+ if (constRows != static_cast <int >(shape.size ())) {
209
+ for (auto extent : shape) {
210
+ if (constRows-- > 0 )
211
+ continue ;
212
+ if (extent != fir::SequenceType::getUnknownExtent ())
213
+ constSize *= extent;
214
+ }
215
+ }
216
+ }
217
+
218
+ if (constSize != 1 ) {
219
+ mlir::Value constVal{
220
+ genConstantIndex (loc, ity, rewriter, constSize).getResult ()};
221
+ return constVal;
222
+ }
223
+ return nullptr ;
224
+ }
225
+
226
+ static mlir::Value integerCast (const fir::LLVMTypeConverter &converter,
227
+ mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
228
+ mlir::Type ty, mlir::Value val, bool fold = false ) {
229
+ auto valTy = val.getType ();
230
+ // If the value was not yet lowered, lower its type so that it can
231
+ // be used in getPrimitiveTypeSizeInBits.
232
+ if (!mlir::isa<mlir::IntegerType>(valTy))
233
+ valTy = converter.convertType (valTy);
234
+ auto toSize = mlir::LLVM::getPrimitiveTypeSizeInBits (ty);
235
+ auto fromSize = mlir::LLVM::getPrimitiveTypeSizeInBits (valTy);
236
+ if (fold) {
237
+ if (toSize < fromSize)
238
+ return rewriter.createOrFold <mlir::LLVM::TruncOp>(loc, ty, val);
239
+ if (toSize > fromSize)
240
+ return rewriter.createOrFold <mlir::LLVM::SExtOp>(loc, ty, val);
241
+ } else {
242
+ if (toSize < fromSize)
243
+ return rewriter.create <mlir::LLVM::TruncOp>(loc, ty, val);
244
+ if (toSize > fromSize)
245
+ return rewriter.create <mlir::LLVM::SExtOp>(loc, ty, val);
246
+ }
247
+ return val;
248
+ }
249
+
250
+ // FIR Op specific conversion for TargetAllocMemOp
251
+ struct TargetAllocMemOpConversion
252
+ : public OpenMPFIROpConversion<mlir::omp::TargetAllocMemOp> {
253
+ using OpenMPFIROpConversion::OpenMPFIROpConversion;
254
+
255
+ llvm::LogicalResult
256
+ matchAndRewrite (mlir::omp::TargetAllocMemOp allocmemOp, OpAdaptor adaptor,
257
+ mlir::ConversionPatternRewriter &rewriter) const override {
258
+ mlir::Type heapTy = allocmemOp.getAllocatedType ();
259
+ mlir::LLVM::LLVMFuncOp mallocFunc = getOmpTargetAlloc (allocmemOp);
260
+ mlir::Location loc = allocmemOp.getLoc ();
261
+ auto ity = lowerTy ().indexType ();
262
+ mlir::Type dataTy = fir::unwrapRefType (heapTy);
263
+ mlir::Type llvmObjectTy = convertObjectType (lowerTy (), dataTy);
264
+ mlir::Type llvmPtrTy = mlir::LLVM::LLVMPointerType::get (allocmemOp.getContext (), 0 );
265
+ if (fir::isRecordWithTypeParameters (fir::unwrapSequenceType (dataTy)))
266
+ TODO (loc, " omp.target_allocmem codegen of derived type with length "
267
+ " parameters" );
268
+ mlir::Value size = genTypeSizeInBytes (loc, ity, rewriter, llvmObjectTy, lowerTy ().getDataLayout ());
269
+ if (auto scaleSize = genAllocationScaleSize (allocmemOp, ity, rewriter))
270
+ size = rewriter.create <mlir::LLVM::MulOp>(loc, ity, size, scaleSize);
271
+ for (mlir::Value opnd : adaptor.getOperands ().drop_front ())
272
+ size = rewriter.create <mlir::LLVM::MulOp>(
273
+ loc, ity, size, integerCast (lowerTy (), loc, rewriter, ity, opnd));
274
+ auto mallocTyWidth = lowerTy ().getIndexTypeBitwidth ();
275
+ auto mallocTy =
276
+ mlir::IntegerType::get (rewriter.getContext (), mallocTyWidth);
277
+ if (mallocTyWidth != ity.getIntOrFloatBitWidth ())
278
+ size = integerCast (lowerTy (), loc, rewriter, mallocTy, size);
279
+ allocmemOp->setAttr (" callee" , mlir::SymbolRefAttr::get (mallocFunc));
280
+ auto callOp = rewriter.create <mlir::LLVM::CallOp>(
281
+ loc, llvmPtrTy,
282
+ mlir::SmallVector<mlir::Value, 2 >({size, allocmemOp.getDevice ()}),
283
+ addLLVMOpBundleAttrs (rewriter, allocmemOp->getAttrs (), 2 ));
284
+ rewriter.replaceOpWithNewOp <mlir::LLVM::PtrToIntOp>(allocmemOp, rewriter.getIntegerType (64 ), callOp.getResult ());
285
+ return mlir::success ();
286
+ }
287
+ };
128
288
} // namespace
129
289
130
290
void fir::populateOpenMPFIRToLLVMConversionPatterns (
131
291
const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
132
292
patterns.add <MapInfoOpConversion>(converter);
133
293
patterns.add <PrivateClauseOpConversion>(converter);
294
+ patterns.add <TargetAllocMemOpConversion>(converter);
134
295
}
0 commit comments