1111 q_constant , q_comp_acc , q_affine , q_routine , search ,
1212 uxreplace )
1313from devito .tools import (Tag , as_mapper , as_tuple , is_integer , filter_sorted ,
14- flatten , memoized_meth , memoized_generator )
14+ flatten , memoized_meth , memoized_generator , smart_gt ,
15+ smart_lt )
1516from devito .types import (ComponentAccess , Dimension , DimensionTuple , Fence ,
1617 CriticalRegion , Function , Symbol , Temp , TempArray ,
1718 TBArray )
@@ -364,11 +365,12 @@ def distance(self, other):
364365 # trip count. E.g. it ranges from 0 to 3; `other` performs a
365366 # constant access at 4
366367 for v in (self [n ], other [n ]):
367- try :
368- if bool (v < sit .symbolic_min or v > sit .symbolic_max ):
369- return Vector (S .ImaginaryUnit )
370- except TypeError :
371- pass
368+ # Note: Uses smart_ comparisons avoid evaluating expensive
369+ # symbolic Lt or Gt operations,
370+ # Note: Boolean is split to make the conditional short
371+ # circuit more frequently for mild speedup.
372+ if smart_lt (v , sit .symbolic_min ) or smart_gt (v , sit .symbolic_max ):
373+ return Vector (S .ImaginaryUnit )
372374
373375 # Case 2: `sit` is an IterationInterval over a local SubDimension
374376 # and `other` performs a constant access
@@ -382,32 +384,36 @@ def distance(self, other):
382384 if disjoint_test (self [n ], other [n ], sai , sit ):
383385 return Vector (S .ImaginaryUnit )
384386
387+ # Compute the distance along the current IterationInterval
385388 if self .function ._mem_shared :
386389 # Special case: the distance between two regular, thread-shared
387- # objects fallbacks to zero, as any other value would be nonsensical
390+ # objects falls back to zero, as any other value would be
391+ # nonsensical
392+ ret .append (S .Zero )
393+ elif degenerating_dimensions (sai , oai ):
394+ # Special case: `sai` and `oai` may be different symbolic objects
395+ # but they can be proved to systematically generate the same value
388396 ret .append (S .Zero )
389-
390397 elif sai and oai and sai ._defines & sit .dim ._defines :
391- # E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`, `ai=t`
398+ # E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`,
399+ # and `ai=t`
392400 if sit .direction is Backward :
393401 ret .append (other [n ] - self [n ])
394402 else :
395403 ret .append (self [n ] - other [n ])
396-
397404 elif not sai and not oai :
398405 # E.g., `self=R<a,[3]>` and `other=W<a,[4]>`
399406 if self [n ] - other [n ] == 0 :
400407 ret .append (S .Zero )
401408 else :
402409 break
403-
404410 elif sai in self .ispace and oai in other .ispace :
405411 # E.g., `self=R<f,[x, y]>`, `sai=time`,
406412 # `self.itintervals=(time, x, y)`, `n=0`
407413 continue
408-
409414 else :
410- # E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`, `n=1`
415+ # E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`,
416+ # and `n=1`
411417 return vinf (ret )
412418
413419 n = len (ret )
@@ -1408,3 +1414,19 @@ def disjoint_test(e0, e1, d, it):
14081414 i1 = sympy .Interval (min (p10 , p11 ), max (p10 , p11 ))
14091415
14101416 return not bool (i0 .intersect (i1 ))
1417+
1418+
1419+ def degenerating_dimensions (d0 , d1 ):
1420+ """
1421+ True if `d0` and `d1` are Dimensions that are possibly symbolically
1422+ different, but they can be proved to systematically degenerate to the
1423+ same value, False otherwise.
1424+ """
1425+ # Case 1: ModuloDimensions of size 1
1426+ try :
1427+ if d0 .is_Modulo and d1 .is_Modulo and d0 .modulo == d1 .modulo == 1 :
1428+ return True
1429+ except AttributeError :
1430+ pass
1431+
1432+ return False
0 commit comments