@@ -145,11 +145,45 @@ struct ArgsUsageInLoop {
145145};
146146} // namespace
147147
148- static fir::SequenceType getAsSequenceType (mlir::Value * v) {
149- mlir::Type argTy = fir::unwrapPassByRefType (fir::unwrapRefType (v-> getType ()));
148+ static fir::SequenceType getAsSequenceType (mlir::Value v) {
149+ mlir::Type argTy = fir::unwrapPassByRefType (fir::unwrapRefType (v. getType ()));
150150 return mlir::dyn_cast<fir::SequenceType>(argTy);
151151}
152152
153+ // / Return the rank and the element size (in bytes) of the given
154+ // / value \p v. If it is not an array or the element type is not
155+ // / supported, then return <0, 0>. Only trivial data types
156+ // / are currently supported.
157+ // / When \p isArgument is true, \p v is assumed to be a function
158+ // / argument. If \p v's type does not look like a type of an assumed
159+ // / shape array, then the function returns <0, 0>.
160+ // / When \p isArgument is false, array types with known innermost
161+ // / dimension are allowed to proceed.
162+ static std::pair<unsigned , size_t >
163+ getRankAndElementSize (const fir::KindMapping &kindMap,
164+ const mlir::DataLayout &dl, mlir::Value v,
165+ bool isArgument = false ) {
166+ if (auto seqTy = getAsSequenceType (v)) {
167+ unsigned rank = seqTy.getDimension ();
168+ if (rank > 0 &&
169+ (!isArgument ||
170+ seqTy.getShape ()[0 ] == fir::SequenceType::getUnknownExtent ())) {
171+ size_t typeSize = 0 ;
172+ mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType (v.getType ());
173+ if (fir::isa_trivial (elementType)) {
174+ auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash (
175+ v.getLoc (), elementType, dl, kindMap);
176+ typeSize = llvm::alignTo (eleSize, eleAlign);
177+ }
178+ if (typeSize)
179+ return {rank, typeSize};
180+ }
181+ }
182+
183+ LLVM_DEBUG (llvm::dbgs () << " Unsupported rank/type: " << v << ' \n ' );
184+ return {0 , 0 };
185+ }
186+
153187// / if a value comes from a fir.declare, follow it to the original source,
154188// / otherwise return the value
155189static mlir::Value unwrapFirDeclare (mlir::Value val) {
@@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
160194 return val;
161195}
162196
197+ // / Return true, if \p rebox operation keeps the input array
198+ // / continuous in the innermost dimension, if it is initially continuous
199+ // / in the innermost dimension.
200+ static bool reboxPreservesContinuity (fir::ReboxOp rebox) {
201+ // If slicing is not involved, then the rebox does not affect
202+ // the continuity of the array.
203+ auto sliceArg = rebox.getSlice ();
204+ if (!sliceArg)
205+ return true ;
206+
207+ // A slice with step=1 in the innermost dimension preserves
208+ // the continuity of the array in the innermost dimension.
209+ if (auto sliceOp =
210+ mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp ())) {
211+ if (sliceOp.getFields ().empty () && sliceOp.getSubstr ().empty ()) {
212+ auto triples = sliceOp.getTriples ();
213+ if (triples.size () > 2 )
214+ if (auto innermostStep = fir::getIntIfConstant (triples[2 ]))
215+ if (*innermostStep == 1 )
216+ return true ;
217+ }
218+
219+ LLVM_DEBUG (llvm::dbgs ()
220+ << " REBOX with slicing may produce non-contiguous array: "
221+ << sliceOp << ' \n '
222+ << rebox << ' \n ' );
223+ return false ;
224+ }
225+
226+ LLVM_DEBUG (llvm::dbgs () << " REBOX with unknown slice" << sliceArg << ' \n '
227+ << rebox << ' \n ' );
228+ return false ;
229+ }
230+
163231// / if a value comes from a fir.rebox, follow the rebox to the original source,
164232// / of the value, otherwise return the value
165233static mlir::Value unwrapReboxOp (mlir::Value val) {
166- // don't support reboxes of reboxes
167- if (fir::ReboxOp rebox = val.getDefiningOp <fir::ReboxOp>())
234+ while (fir::ReboxOp rebox = val.getDefiningOp <fir::ReboxOp>()) {
235+ if (!reboxPreservesContinuity (rebox))
236+ break ;
168237 val = rebox.getBox ();
238+ }
169239 return val;
170240}
171241
@@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() {
257327 continue ;
258328 }
259329
260- if (auto seqTy = getAsSequenceType (&arg)) {
261- unsigned rank = seqTy.getDimension ();
262- if (rank > 0 &&
263- seqTy.getShape ()[0 ] == fir::SequenceType::getUnknownExtent ()) {
264- size_t typeSize = 0 ;
265- mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType (arg.getType ());
266- if (mlir::isa<mlir::FloatType>(elementType) ||
267- mlir::isa<mlir::IntegerType>(elementType) ||
268- mlir::isa<mlir::ComplexType>(elementType)) {
269- auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash (
270- arg.getLoc (), elementType, *dl, kindMap);
271- typeSize = llvm::alignTo (eleSize, eleAlign);
272- }
273- if (typeSize)
274- argsOfInterest.push_back ({arg, typeSize, rank, {}});
275- else
276- LLVM_DEBUG (llvm::dbgs () << " Type not supported\n " );
277- }
278- }
330+ auto [rank, typeSize] =
331+ getRankAndElementSize (kindMap, *dl, arg, /* isArgument=*/ true );
332+ if (rank != 0 && typeSize != 0 )
333+ argsOfInterest.push_back ({arg, typeSize, rank, {}});
279334 }
280335
281336 if (argsOfInterest.empty ()) {
@@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() {
326381 if (arrayCoor.getSlice ())
327382 argsInLoop.cannotTransform .insert (a.arg );
328383
384+ // We need to compute the rank and element size
385+ // based on the operand, not the original argument,
386+ // because array slicing may affect it.
387+ std::tie (a.rank , a.size ) = getRankAndElementSize (kindMap, *dl, a.arg );
388+ if (a.rank == 0 || a.size == 0 )
389+ argsInLoop.cannotTransform .insert (a.arg );
390+
329391 if (argsInLoop.cannotTransform .contains (a.arg )) {
330392 // Remove any previously recorded usage, if any.
331393 argsInLoop.usageInfo .erase (a.arg );
@@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() {
416478 mlir::Location loc = builder.getUnknownLoc ();
417479 mlir::IndexType idxTy = builder.getIndexType ();
418480
419- LLVM_DEBUG (llvm::dbgs () << " Module Before transformation:" );
420- LLVM_DEBUG (module ->dump ());
481+ LLVM_DEBUG (llvm::dbgs () << " Func Before transformation:\n " );
482+ LLVM_DEBUG (func ->dump ());
421483
422484 LLVM_DEBUG (llvm::dbgs () << " loopsOfInterest: " << loopsOfInterest.size ()
423485 << " \n " );
@@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() {
551613 }
552614 }
553615
554- LLVM_DEBUG (llvm::dbgs () << " After transform:\n " );
555- LLVM_DEBUG (module ->dump ());
616+ LLVM_DEBUG (llvm::dbgs () << " Func After transform:\n " );
617+ LLVM_DEBUG (func ->dump ());
556618
557619 LLVM_DEBUG (llvm::dbgs () << " === End " DEBUG_TYPE " ===\n " );
558620}
0 commit comments