From 396107623e53c605f2107a5cd6a057ddfd791574 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Mon, 10 Nov 2025 13:15:09 -0800
Subject: [PATCH 1/2] [flang][CUDA] Unify element size computation in CUF helpers

Refactor computeWidth from CUFOpConversion into a shared helper function
computeElementByteSize in CUFCommon.
---
 .../flang/Optimizer/Builder/CUFCommon.h       |  5 +++
 flang/lib/Optimizer/Builder/CUFCommon.cpp     | 23 +++++++++++++
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 32 +++----------------
 3 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/CUFCommon.h b/flang/include/flang/Optimizer/Builder/CUFCommon.h
index 5c56dd6b695f8..6e2442745f9a0 100644
--- a/flang/include/flang/Optimizer/Builder/CUFCommon.h
+++ b/flang/include/flang/Optimizer/Builder/CUFCommon.h
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
 
 namespace fir {
 class FirOpBuilder;
+class KindMapping;
 } // namespace fir
 
 namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
 
 void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
 
+int computeElementByteSize(mlir::Location loc, mlir::Type type,
+                           fir::KindMapping &kindMap,
+                           bool emitErrorOnFailure = true);
+
 } // namespace cuf
 
 #endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
diff --git a/flang/lib/Optimizer/Builder/CUFCommon.cpp b/flang/lib/Optimizer/Builder/CUFCommon.cpp
index cf7588f275d22..461deb8e4cb55 100644
--- a/flang/lib/Optimizer/Builder/CUFCommon.cpp
+++ b/flang/lib/Optimizer/Builder/CUFCommon.cpp
@@ -9,6 +9,7 @@
 #include "flang/Optimizer/Builder/CUFCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
     }
   }
 }
+
+int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
+                                fir::KindMapping &kindMap,
+                                bool emitErrorOnFailure) {
+  auto eleTy = fir::unwrapSequenceType(type);
+  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
+    return t.getWidth() / 8;
+  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
+    return t.getWidth() / 8;
+  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
+    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
+  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
+    int elemSize =
+        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
+    return 2 * elemSize;
+  }
+  if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
+    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
+  if (emitErrorOnFailure)
+    mlir::emitError(loc, "unsupported type");
+  return 0;
+}
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8d00272b09f42..a61b337fc4152 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
   return false;
 }
 
-static int computeWidth(mlir::Location loc, mlir::Type type,
-                        fir::KindMapping &kindMap) {
-  auto eleTy = fir::unwrapSequenceType(type);
-  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
-    return t.getWidth() / 8;
-  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
-    return t.getWidth() / 8;
-  if (eleTy.isInteger(1))
-    return 1;
-  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
-    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
-  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
-    int elemSize =
-        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
-    return 2 * elemSize;
-  }
-  if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
-    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
-  mlir::emitError(loc, "unsupported type");
-  return 0;
-}
-
 struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
       mlir::Value bytes;
       fir::KindMapping kindMap{fir::getKindMapping(mod)};
       if (fir::isa_trivial(op.getInType())) {
-        int width = computeWidth(loc, op.getInType(), kindMap);
+        int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
         bytes =
             builder.createIntegerConstant(loc, builder.getIndexType(), width);
       } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
           mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
           size = dl->getTypeSizeInBits(structTy) / 8;
         } else {
-          size = computeWidth(loc, seqTy.getEleTy(), kindMap);
+          size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
         }
         mlir::Value width =
             builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
                               const mlir::SymbolTable &symtab,
                               mlir::DataLayout *dl,
                               const fir::LLVMTypeConverter *typeConverter)
-      : OpRewritePattern(context), symtab{symtab}, dl{dl},
-        typeConverter{typeConverter} {}
+      : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
+                                                               typeConverter} {}
 
   mlir::LogicalResult
   matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
             typeConverter->convertType(fir::unwrapSequenceType(dstTy));
         width = dl->getTypeSizeInBits(structTy) / 8;
       } else {
-        width = computeWidth(loc, dstTy, kindMap);
+        width = cuf::computeElementByteSize(loc, dstTy, kindMap);
       }
       mlir::Value widthValue = mlir::arith::ConstantOp::create(
           rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));

From c68eaba4e223254b65af12d0b407d4d1689fb664 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Mon, 10 Nov 2025 14:10:45 -0800
Subject: [PATCH 2/2] format

---
 flang/lib/Optimizer/Transforms/CUFOpConversion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index a61b337fc4152..5b1b0a2f6feab 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -597,8 +597,8 @@ struct CUFDataTransferOpConversion
                               const mlir::SymbolTable &symtab,
                               mlir::DataLayout *dl,
                               const fir::LLVMTypeConverter *typeConverter)
-      : OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
-                                                               typeConverter} {}
+      : OpRewritePattern(context), symtab{symtab}, dl{dl},
+        typeConverter{typeConverter} {}
 
   mlir::LogicalResult
   matchAndRewrite(cuf::DataTransferOp op,