@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
263263 return false ;
264264}
265265
266- static int computeWidth (mlir::Location loc, mlir::Type type,
267- fir::KindMapping &kindMap) {
268- auto eleTy = fir::unwrapSequenceType (type);
269- if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
270- return t.getWidth () / 8 ;
271- if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
272- return t.getWidth () / 8 ;
273- if (eleTy.isInteger (1 ))
274- return 1 ;
275- if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
276- return kindMap.getLogicalBitsize (t.getFKind ()) / 8 ;
277- if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
278- int elemSize =
279- mlir::cast<mlir::FloatType>(t.getElementType ()).getWidth () / 8 ;
280- return 2 * elemSize;
281- }
282- if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
283- return kindMap.getCharacterBitsize (t.getFKind ()) / 8 ;
284- mlir::emitError (loc, " unsupported type" );
285- return 0 ;
286- }
287-
288266struct CUFAllocOpConversion : public mlir ::OpRewritePattern<cuf::AllocOp> {
289267 using OpRewritePattern::OpRewritePattern;
290268
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
320298 mlir::Value bytes;
321299 fir::KindMapping kindMap{fir::getKindMapping (mod)};
322300 if (fir::isa_trivial (op.getInType ())) {
323- int width = computeWidth (loc, op.getInType (), kindMap);
301+ int width = cuf::computeElementByteSize (loc, op.getInType (), kindMap);
324302 bytes =
325303 builder.createIntegerConstant (loc, builder.getIndexType (), width);
326304 } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
330308 mlir::Type structTy = typeConverter->convertType (seqTy.getEleTy ());
331309 size = dl->getTypeSizeInBits (structTy) / 8 ;
332310 } else {
333- size = computeWidth (loc, seqTy.getEleTy (), kindMap);
311+ size = cuf::computeElementByteSize (loc, seqTy.getEleTy (), kindMap);
334312 }
335313 mlir::Value width =
336314 builder.createIntegerConstant (loc, builder.getIndexType (), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
619597 const mlir::SymbolTable &symtab,
620598 mlir::DataLayout *dl,
621599 const fir::LLVMTypeConverter *typeConverter)
622- : OpRewritePattern(context), symtab{symtab}, dl{dl},
623- typeConverter{ typeConverter} {}
600+ : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
601+ typeConverter} {}
624602
625603 mlir::LogicalResult
626604 matchAndRewrite (cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
704682 typeConverter->convertType (fir::unwrapSequenceType (dstTy));
705683 width = dl->getTypeSizeInBits (structTy) / 8 ;
706684 } else {
707- width = computeWidth (loc, dstTy, kindMap);
685+ width = cuf::computeElementByteSize (loc, dstTy, kindMap);
708686 }
709687 mlir::Value widthValue = mlir::arith::ConstantOp::create (
710688 rewriter, loc, i64Ty, rewriter.getIntegerAttr (i64Ty, width));
0 commit comments