Skip to content

Commit 1b8a4aa

Browse files
authored
[flang][cuda] Extract element count computation into helper function (#168937)
This patch extracts the common logic for computing array element counts from shape operands into a reusable helper function in CUFCommon.
1 parent 7e43715 commit 1b8a4aa

File tree

3 files changed

+47
-25
lines changed

3 files changed

+47
-25
lines changed

flang/include/flang/Optimizer/Builder/CUFCommon.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ int computeElementByteSize(mlir::Location loc, mlir::Type type,
3939
fir::KindMapping &kindMap,
4040
bool emitErrorOnFailure = true);
4141

42+
/// Compute the number of elements of an array as a value of \p targetType.
///
/// If \p shapeOperand is set, the extents are read from the fir.shape or
/// fir.shape_shift operation defining it; each extent is converted to
/// \p targetType and the converted values are multiplied together.
/// If \p shapeOperand is null and \p seqType is a fir::SequenceType, the
/// constant array size of that type is emitted as a constant of \p targetType.
///
/// \returns the element count, or a null mlir::Value when no extent could be
/// obtained from \p shapeOperand, or when \p shapeOperand is null and
/// \p seqType is not a fir::SequenceType. Callers must check the result.
mlir::Value computeElementCount(mlir::PatternRewriter &rewriter,
43+
mlir::Location loc, mlir::Value shapeOperand,
44+
mlir::Type seqType, mlir::Type targetType);
45+
4246
} // namespace cuf
4347

4448
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

flang/lib/Optimizer/Builder/CUFCommon.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,44 @@ int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
114114
mlir::emitError(loc, "unsupported type");
115115
return 0;
116116
}
117+
118+
/// Compute the number of elements of an array as a value of \p targetType.
///
/// If \p shapeOperand is set, the extents are read from the fir.shape or
/// fir.shape_shift operation defining it; each extent is converted to
/// \p targetType and the converted values are multiplied together.
/// If \p shapeOperand is null and \p seqType is a fir::SequenceType, the
/// constant array size of that type is emitted as a constant of \p targetType.
///
/// \returns the element count, or a null mlir::Value when no extent could be
/// obtained from \p shapeOperand (including when it is not produced by a
/// fir.shape / fir.shape_shift operation), or when \p shapeOperand is null and
/// \p seqType is not a fir::SequenceType. Callers must check the result.
mlir::Value cuf::computeElementCount(mlir::PatternRewriter &rewriter,
                                     mlir::Location loc,
                                     mlir::Value shapeOperand,
                                     mlir::Type seqType,
                                     mlir::Type targetType) {
  if (shapeOperand) {
    // Dynamic extent - extract from shape operand.
    // Value::getDefiningOp<OpTy>() tolerates values without a defining
    // operation (e.g. block arguments); a plain dyn_cast on the raw
    // getDefiningOp() result would assert on the null pointer.
    llvm::SmallVector<mlir::Value> extents;
    if (auto shapeOp = shapeOperand.getDefiningOp<fir::ShapeOp>()) {
      extents = shapeOp.getExtents();
    } else if (auto shapeShiftOp =
                   shapeOperand.getDefiningOp<fir::ShapeShiftOp>()) {
      // Only the odd-indexed entries of the pairs list are kept as extents.
      for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
        if (i.index() & 1)
          extents.push_back(i.value());
    }

    // No extent available: report failure instead of indexing extents[0].
    if (extents.empty())
      return mlir::Value();

    // Compute total element count by multiplying all dimensions.
    mlir::Value count =
        fir::ConvertOp::create(rewriter, loc, targetType, extents[0]);
    for (unsigned i = 1; i < extents.size(); ++i) {
      auto operand =
          fir::ConvertOp::create(rewriter, loc, targetType, extents[i]);
      count = mlir::arith::MulIOp::create(rewriter, loc, count, operand);
    }
    return count;
  }

  // Static extent - use constant array size.
  if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(seqType)) {
    mlir::IntegerAttr attr =
        rewriter.getIntegerAttr(targetType, seqTy.getConstantArraySize());
    return mlir::arith::ConstantOp::create(rewriter, loc, targetType, attr);
  }
  return mlir::Value();
}

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -651,31 +651,8 @@ struct CUFDataTransferOpConversion
651651
}
652652

653653
mlir::Type i64Ty = builder.getI64Type();
654-
mlir::Value nbElement;
655-
if (op.getShape()) {
656-
llvm::SmallVector<mlir::Value> extents;
657-
if (auto shapeOp =
658-
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
659-
extents = shapeOp.getExtents();
660-
} else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
661-
op.getShape().getDefiningOp())) {
662-
for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
663-
if (i.index() & 1)
664-
extents.push_back(i.value());
665-
}
666-
667-
nbElement = fir::ConvertOp::create(rewriter, loc, i64Ty, extents[0]);
668-
for (unsigned i = 1; i < extents.size(); ++i) {
669-
auto operand =
670-
fir::ConvertOp::create(rewriter, loc, i64Ty, extents[i]);
671-
nbElement =
672-
mlir::arith::MulIOp::create(rewriter, loc, nbElement, operand);
673-
}
674-
} else {
675-
if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
676-
nbElement = builder.createIntegerConstant(
677-
loc, i64Ty, seqTy.getConstantArraySize());
678-
}
654+
mlir::Value nbElement =
655+
cuf::computeElementCount(rewriter, loc, op.getShape(), dstTy, i64Ty);
679656
unsigned width = 0;
680657
if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
681658
mlir::Type structTy =

0 commit comments

Comments
 (0)