Skip to content

Commit 3961076

Browse files
committed
[flang][CUDA] Unify element size computation in CUF helpers
Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.
1 parent f63d33d commit 3961076

File tree

3 files changed

+33
-27
lines changed

3 files changed

+33
-27
lines changed

flang/include/flang/Optimizer/Builder/CUFCommon.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
1818

1919
namespace fir {
2020
class FirOpBuilder;
21+
class KindMapping;
2122
} // namespace fir
2223

2324
namespace cuf {
@@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
3435

3536
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
3637

38+
int computeElementByteSize(mlir::Location loc, mlir::Type type,
39+
fir::KindMapping &kindMap,
40+
bool emitErrorOnFailure = true);
41+
3742
} // namespace cuf
3843

3944
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

flang/lib/Optimizer/Builder/CUFCommon.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "flang/Optimizer/Builder/CUFCommon.h"
1010
#include "flang/Optimizer/Builder/FIRBuilder.h"
1111
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
12+
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
1213
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1314
#include "mlir/Dialect/Func/IR/FuncOps.h"
1415
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
9192
}
9293
}
9394
}
95+
96+
int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
97+
fir::KindMapping &kindMap,
98+
bool emitErrorOnFailure) {
99+
auto eleTy = fir::unwrapSequenceType(type);
100+
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
101+
return t.getWidth() / 8;
102+
if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
103+
return t.getWidth() / 8;
104+
if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
105+
return kindMap.getLogicalBitsize(t.getFKind()) / 8;
106+
if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
107+
int elemSize =
108+
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
109+
return 2 * elemSize;
110+
}
111+
if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
112+
return kindMap.getCharacterBitsize(t.getFKind()) / 8;
113+
if (emitErrorOnFailure)
114+
mlir::emitError(loc, "unsupported type");
115+
return 0;
116+
}

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
263263
return false;
264264
}
265265

266-
static int computeWidth(mlir::Location loc, mlir::Type type,
267-
fir::KindMapping &kindMap) {
268-
auto eleTy = fir::unwrapSequenceType(type);
269-
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
270-
return t.getWidth() / 8;
271-
if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
272-
return t.getWidth() / 8;
273-
if (eleTy.isInteger(1))
274-
return 1;
275-
if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
276-
return kindMap.getLogicalBitsize(t.getFKind()) / 8;
277-
if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
278-
int elemSize =
279-
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
280-
return 2 * elemSize;
281-
}
282-
if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
283-
return kindMap.getCharacterBitsize(t.getFKind()) / 8;
284-
mlir::emitError(loc, "unsupported type");
285-
return 0;
286-
}
287-
288266
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
289267
using OpRewritePattern::OpRewritePattern;
290268

@@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
320298
mlir::Value bytes;
321299
fir::KindMapping kindMap{fir::getKindMapping(mod)};
322300
if (fir::isa_trivial(op.getInType())) {
323-
int width = computeWidth(loc, op.getInType(), kindMap);
301+
int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
324302
bytes =
325303
builder.createIntegerConstant(loc, builder.getIndexType(), width);
326304
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
330308
mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
331309
size = dl->getTypeSizeInBits(structTy) / 8;
332310
} else {
333-
size = computeWidth(loc, seqTy.getEleTy(), kindMap);
311+
size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
334312
}
335313
mlir::Value width =
336314
builder.createIntegerConstant(loc, builder.getIndexType(), size);
@@ -619,8 +597,8 @@ struct CUFDataTransferOpConversion
619597
const mlir::SymbolTable &symtab,
620598
mlir::DataLayout *dl,
621599
const fir::LLVMTypeConverter *typeConverter)
622-
: OpRewritePattern(context), symtab{symtab}, dl{dl},
623-
typeConverter{typeConverter} {}
600+
: OpRewritePattern(context), symtab{symtab}, dl{dl}, typeConverter{
601+
typeConverter} {}
624602

625603
mlir::LogicalResult
626604
matchAndRewrite(cuf::DataTransferOp op,
@@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
704682
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
705683
width = dl->getTypeSizeInBits(structTy) / 8;
706684
} else {
707-
width = computeWidth(loc, dstTy, kindMap);
685+
width = cuf::computeElementByteSize(loc, dstTy, kindMap);
708686
}
709687
mlir::Value widthValue = mlir::arith::ConstantOp::create(
710688
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));

0 commit comments

Comments
 (0)