@@ -159,28 +159,162 @@ containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect,
159159 return mlir::AliasResult::NoAlias;
160160}
161161
162- // Returns true if the given array references represent identical
163- // or completely disjoint array slices. The callers may use this
164- // method when the alias analysis reports an alias of some kind,
165- // so that we can run Fortran specific analysis on the array slices
166- // to see if they are identical or disjoint. Note that the alias
167- // analysis are not able to give such an answer about the references.
168- static bool areIdenticalOrDisjointSlices (mlir::Value ref1, mlir::Value ref2) {
162+ // Helper class for analyzing two array slices represented
163+ // by two hlfir.designate operations.
164+ class ArraySectionAnalyzer {
165+ public:
166+ // The result of the analysis is one of the values below.
167+ enum class SlicesOverlapKind {
168+ // Slices overlap is unknown.
169+ Unknown,
170+ // Slices are definitely identical.
171+ DefinitelyIdentical,
172+ // Slices are definitely disjoint.
173+ DefinitelyDisjoint,
174+ // Slices may be either disjoint or identical,
175+ // i.e. there is definitely no partial overlap.
176+ EitherIdenticalOrDisjoint
177+ };
178+
179+ // Analyzes two hlfir.designate results and returns the overlap kind.
180+ // The callers may use this method when the alias analysis reports
181+ // an alias of some kind, so that we can run Fortran specific analysis
182+ // on the array slices to see if they are identical or disjoint.
183+ // Note that the alias analysis is not able to give such an answer
184+ // about the references.
185+ static SlicesOverlapKind analyze (mlir::Value ref1, mlir::Value ref2);
186+
187+ private:
188+ struct SectionDesc {
189+ // An array section is described by <lb, ub, stride> tuple.
190+ // If the designator's subscript is not a triplet, then
191+ // the section descriptor is constructed as <lb, nullptr, nullptr>.
192+ mlir::Value lb, ub, stride;
193+
194+ SectionDesc (mlir::Value lb, mlir::Value ub, mlir::Value stride)
195+ : lb(lb), ub(ub), stride(stride) {
196+ assert (lb && " lower bound or index must be specified" );
197+ normalize ();
198+ }
199+
200+ // Normalize the section descriptor:
201+ // 1. If UB is nullptr, then it is set to LB.
202+ // 2. If LB==UB, then stride does not matter,
203+ // so it is reset to nullptr.
204+ // 3. If STRIDE==1, then it is reset to nullptr.
205+ void normalize () {
206+ if (!ub)
207+ ub = lb;
208+ if (lb == ub)
209+ stride = nullptr ;
210+ if (stride)
211+ if (auto val = fir::getIntIfConstant (stride))
212+ if (*val == 1 )
213+ stride = nullptr ;
214+ }
215+
216+ bool operator ==(const SectionDesc &other) const {
217+ return lb == other.lb && ub == other.ub && stride == other.stride ;
218+ }
219+ };
220+
221+ // Given an operand_iterator over the indices operands,
222+ // read the subscript values and return them as SectionDesc
223+ // updating the iterator. If isTriplet is true,
224+ // the subscript is a triplet, and the result is <lb, ub, stride>.
225+ // Otherwise, the subscript is a scalar index, and the result is built
226+ // from <index, nullptr, nullptr>, i.e. normalized to <index, index, nullptr>.
227+ static SectionDesc readSectionDesc (mlir::Operation::operand_iterator &it,
228+ bool isTriplet) {
229+ if (isTriplet)
230+ return {*it++, *it++, *it++}; // braced-init-list: evaluated left to right
231+ return {*it++, nullptr , nullptr };
232+ }
233+
234+ // Return the ordered lower and upper bounds of the section.
235+ // If stride is known to be non-negative, then the ordered
236+ // bounds match the <lb, ub> of the descriptor.
237+ // If stride is known to be negative, then the ordered
238+ // bounds are <ub, lb> of the descriptor.
239+ // If stride is unknown, we cannot deduce any order,
240+ // so the result is <nullptr, nullptr>.
241+ static std::pair<mlir::Value, mlir::Value>
242+ getOrderedBounds (const SectionDesc &desc) {
243+ mlir::Value stride = desc.stride ;
244+ // Null stride means stride=1 (or lb==ub, where stride is irrelevant).
245+ if (!stride)
246+ return {desc.lb , desc.ub };
247+ // Reverse the bounds, if stride is negative.
248+ if (auto val = fir::getIntIfConstant (stride)) {
249+ if (*val >= 0 )
250+ return {desc.lb , desc.ub };
251+ else
252+ return {desc.ub , desc.lb };
253+ }
254+
255+ return {nullptr , nullptr };
256+ }
257+
258+ // Given two array sections <lb1, ub1, stride1> and
259+ // <lb2, ub2, stride2>, return true only if the sections
260+ // are known to be disjoint.
261+ //
262+ // For example, for any positive constant C:
263+ // X:Y does not overlap with (Y+C):Z
264+ // X:Y does not overlap with Z:(X-C)
265+ static bool areDisjointSections (const SectionDesc &desc1,
266+ const SectionDesc &desc2) {
267+ auto [lb1, ub1] = getOrderedBounds (desc1);
268+ auto [lb2, ub2] = getOrderedBounds (desc2);
269+ if (!lb1 || !lb2)
270+ return false ;
271+ // Note that this comparison must be made on the ordered bounds,
272+ // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated
273+ // as not overlapping (x=2, y=10, z=9).
274+ if (isLess (ub1, lb2) || isLess (ub2, lb1))
275+ return true ;
276+ return false ;
277+ }
278+
279+ // Given two array sections <lb1, ub1, stride1> and
280+ // <lb2, ub2, stride2>, return true only if the sections
281+ // are known to be identical.
282+ //
283+ // For example:
284+ // <x, x, stride>
285+ // <x, nullptr, nullptr>
286+ //
287+ // These sections are identical from the point of view of which
288+ // array elements are being addressed, even though the shape
289+ // of the array slices might be different.
290+ static bool areIdenticalSections (const SectionDesc &desc1,
291+ const SectionDesc &desc2) {
292+ if (desc1 == desc2)
293+ return true ;
294+ return false ;
295+ }
296+
297+ // Return true, if v1 is known to be less than v2.
298+ static bool isLess (mlir::Value v1, mlir::Value v2);
299+ };
300+
301+ ArraySectionAnalyzer::SlicesOverlapKind
302+ ArraySectionAnalyzer::analyze (mlir::Value ref1, mlir::Value ref2) {
169303 if (ref1 == ref2)
170- return true ;
304+ return SlicesOverlapKind::DefinitelyIdentical ;
171305
172306 auto des1 = ref1.getDefiningOp <hlfir::DesignateOp>();
173307 auto des2 = ref2.getDefiningOp <hlfir::DesignateOp>();
174308 // We only support a pair of designators right now.
175309 if (!des1 || !des2)
176- return false ;
310+ return SlicesOverlapKind::Unknown ;
177311
178312 if (des1.getMemref () != des2.getMemref ()) {
179313 // If the bases are different, then there is unknown overlap.
180314 LLVM_DEBUG (llvm::dbgs () << " No identical base for:\n "
181315 << des1 << " and:\n "
182316 << des2 << " \n " );
183- return false ;
317+ return SlicesOverlapKind::Unknown ;
184318 }
185319
186320 // Require all components of the designators to be the same.
@@ -194,104 +328,105 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
194328 LLVM_DEBUG (llvm::dbgs () << " Different designator specs for:\n "
195329 << des1 << " and:\n "
196330 << des2 << " \n " );
197- return false ;
198- }
199-
200- if (des1.getIsTriplet () != des2.getIsTriplet ()) {
201- LLVM_DEBUG (llvm::dbgs () << " Different sections for:\n "
202- << des1 << " and:\n "
203- << des2 << " \n " );
204- return false ;
331+ return SlicesOverlapKind::Unknown;
205332 }
206333
207334 // Analyze the subscripts.
208- // For example:
209- // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) shape %9
210- // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) shape %9
211- //
212- // If all the triplets (section speficiers) are the same, then
213- // we do not care if %0 is equal to %1 - the slices are either
214- // identical or completely disjoint.
215335 auto des1It = des1.getIndices ().begin ();
216336 auto des2It = des2.getIndices ().begin ();
217337 bool identicalTriplets = true ;
218- for (bool isTriplet : des1.getIsTriplet ()) {
219- if (isTriplet) {
220- for (int i = 0 ; i < 3 ; ++i)
221- if (*des1It++ != *des2It++) {
222- LLVM_DEBUG (llvm::dbgs () << " Triplet mismatch for:\n "
223- << des1 << " and:\n "
224- << des2 << " \n " );
225- identicalTriplets = false ;
226- break ;
227- }
228- } else {
229- ++des1It;
230- ++des2It;
338+ bool identicalIndices = true ;
339+ for (auto [isTriplet1, isTriplet2] :
340+ llvm::zip (des1.getIsTriplet (), des2.getIsTriplet ())) {
341+ SectionDesc desc1 = readSectionDesc (des1It, isTriplet1);
342+ SectionDesc desc2 = readSectionDesc (des2It, isTriplet2);
343+
344+ // See if we can prove that any of the sections do not overlap.
345+ // This is mostly a Polyhedron/nf performance hack that looks for
346+ // particular relations between the lower and upper bounds
347+ // of the array sections, e.g. for any positive constant C:
348+ // X:Y does not overlap with (Y+C):Z
349+ // X:Y does not overlap with Z:(X-C)
350+ if (areDisjointSections (desc1, desc2))
351+ return SlicesOverlapKind::DefinitelyDisjoint;
352+
353+ if (!areIdenticalSections (desc1, desc2)) {
354+ if (isTriplet1 || isTriplet2) {
355+ // For example:
356+ // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0)
357+ // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1)
358+ //
359+ // If all the triplets (section specifiers) are the same, then
360+ // we do not care if %0 is equal to %1 - the slices are either
361+ // identical or completely disjoint.
362+ //
363+ // Also, treat these as identical sections:
364+ // hlfir.designate %6#0 (%c2:%c2:%c1)
365+ // hlfir.designate %6#0 (%c2)
366+ identicalTriplets = false ;
367+ LLVM_DEBUG (llvm::dbgs () << " Triplet mismatch for:\n "
368+ << des1 << " and:\n "
369+ << des2 << " \n " );
370+ } else {
371+ identicalIndices = false ;
372+ LLVM_DEBUG (llvm::dbgs () << " Indices mismatch for:\n "
373+ << des1 << " and:\n "
374+ << des2 << " \n " );
375+ }
231376 }
232377 }
233- if (identicalTriplets)
234- return true ;
235378
236- // See if we can prove that any of the triplets do not overlap.
237- // This is mostly a Polyhedron/nf performance hack that looks for
238- // particular relations between the lower and upper bounds
239- // of the array sections, e.g. for any positive constant C:
240- // X:Y does not overlap with (Y+C):Z
241- // X:Y does not overlap with Z:(X-C)
242- auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
243- auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
244- auto *op = v.getDefiningOp ();
245- while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
246- op = conv.getValue ().getDefiningOp ();
247- return op;
248- };
379+ if (identicalTriplets) {
380+ if (identicalIndices)
381+ return SlicesOverlapKind::DefinitelyIdentical;
382+ else
383+ return SlicesOverlapKind::EitherIdenticalOrDisjoint;
384+ }
249385
250- auto isPositiveConstant = [](mlir::Value v) -> bool {
251- if (auto conOp =
252- mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp ()))
253- if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(conOp.getValue ()))
254- return iattr.getInt () > 0 ;
255- return false ;
256- };
386+ LLVM_DEBUG (llvm::dbgs () << " Different sections for:\n "
387+ << des1 << " and:\n "
388+ << des2 << " \n " );
389+ return SlicesOverlapKind::Unknown;
390+ }
257391
258- auto *op1 = removeConvert (v1);
259- auto *op2 = removeConvert (v2);
260- if (!op1 || !op2)
261- return false ;
262- if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
263- if ((addi.getLhs ().getDefiningOp () == op1 &&
264- isPositiveConstant (addi.getRhs ())) ||
265- (addi.getRhs ().getDefiningOp () == op1 &&
266- isPositiveConstant (addi.getLhs ())))
267- return true ;
268- if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
269- if (subi.getLhs ().getDefiningOp () == op2 &&
270- isPositiveConstant (subi.getRhs ()))
271- return true ;
392+ bool ArraySectionAnalyzer::isLess (mlir::Value v1, mlir::Value v2) {
393+ auto removeConvert = [](mlir::Value v) -> mlir::Operation * { // look through fir::ConvertOp chains
394+ auto *op = v.getDefiningOp ();
395+ while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
396+ op = conv.getValue ().getDefiningOp ();
397+ return op;
398+ };
399+
400+ auto isPositiveConstant = [](mlir::Value v) -> bool {
401+ if (auto val = fir::getIntIfConstant (v))
402+ return *val > 0 ;
272403 return false ;
273404 };
274405
275- des1It = des1.getIndices ().begin ();
276- des2It = des2.getIndices ().begin ();
277- for (bool isTriplet : des1.getIsTriplet ()) {
278- if (isTriplet) {
279- mlir::Value des1Lb = *des1It++;
280- mlir::Value des1Ub = *des1It++;
281- mlir::Value des2Lb = *des2It++;
282- mlir::Value des2Ub = *des2It++;
283- // Ignore strides.
284- ++des1It;
285- ++des2It;
286- if (displacedByConstant (des1Ub, des2Lb) ||
287- displacedByConstant (des2Ub, des1Lb))
288- return true ;
289- } else {
290- ++des1It;
291- ++des2It;
292- }
293- }
406+ auto *op1 = removeConvert (v1);
407+ auto *op2 = removeConvert (v2);
408+ if (!op1 || !op2) // no defining op found for one of the values: order is unknown
409+ return false ;
294410
411+ // If both values are constants, compare them directly.
412+ if (auto val1 = fir::getIntIfConstant (op1->getResult (0 )))
413+ if (auto val2 = fir::getIntIfConstant (op2->getResult (0 )))
414+ return *val1 < *val2;
415+
416+ // Handle some non-constant cases, where C is a positive constant:
417+ // v2 = v1 + C
418+ // v2 = C + v1
419+ // v1 = v2 - C
420+ if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
421+ if ((addi.getLhs ().getDefiningOp () == op1 &&
422+ isPositiveConstant (addi.getRhs ())) ||
423+ (addi.getRhs ().getDefiningOp () == op1 &&
424+ isPositiveConstant (addi.getLhs ())))
425+ return true ;
426+ if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
427+ if (subi.getLhs ().getDefiningOp () == op2 &&
428+ isPositiveConstant (subi.getRhs ()))
429+ return true ;
295430 return false ;
296431}
297432
@@ -405,21 +540,27 @@ ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) {
405540 if (!res.isPartial ()) {
406541 if (auto designate =
407542 effect.getValue ().getDefiningOp <hlfir::DesignateOp>()) {
408- if (!areIdenticalOrDisjointSlices (match.array , designate.getMemref ())) {
543+ ArraySectionAnalyzer::SlicesOverlapKind overlap =
544+ ArraySectionAnalyzer::analyze (match.array , designate.getMemref ());
545+ if (overlap ==
546+ ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint)
547+ continue ;
548+
549+ if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) {
409550 LLVM_DEBUG (llvm::dbgs () << " possible read conflict: " << designate
410551 << " at " << elemental.getLoc () << " \n " );
411552 return std::nullopt ;
412553 }
413554 auto indices = designate.getIndices ();
414555 auto elementalIndices = elemental.getIndices ();
415- if (indices.size () != elementalIndices.size ()) {
416- LLVM_DEBUG (llvm::dbgs () << " possible read conflict: " << designate
417- << " at " << elemental.getLoc () << " \n " );
418- return std::nullopt ;
419- }
420- if (std::equal (indices.begin (), indices.end (), elementalIndices.begin (),
556+ if (indices.size () == elementalIndices.size () &&
557+ std::equal (indices.begin (), indices.end (), elementalIndices.begin (),
421558 elementalIndices.end ()))
422559 continue ;
560+
561+ LLVM_DEBUG (llvm::dbgs () << " possible read conflict: " << designate
562+ << " at " << elemental.getLoc () << " \n " );
563+ return std::nullopt ;
423564 }
424565 }
425566 LLVM_DEBUG (llvm::dbgs () << " disallowed side-effect: " << effect.getValue ()
0 commit comments