2121 CustomOp ,
2222 GatherToLDS ,
2323 TensorLoadToLDS ,
24+ MemoryCounterWait ,
2425 Read ,
2526 Write ,
2627 get_custom ,
@@ -424,6 +425,7 @@ def minimize_placement_strategy(
424425 Efficient greedy barrier placement.
425426 - Forward hazards: O(n log n) sort + O(n) sweep with a single "last_pos".
426427 - Cross-iter hazards: two O(log m) range checks via binary search over an always-sorted list of chosen placement positions.
428+ - Skips hazards that already have MemoryCounterWait synchronization.
427429 """
428430 if not sync_regions :
429431 return []
@@ -445,6 +447,42 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
445447 return False
446448 return any (lo <= p <= hi for p in ranges )
447449
450+
def has_memory_counter_wait(region: SyncRegion) -> bool:
    """
    Check if there's a MemoryCounterWait node between producer and consumer.

    MemoryCounterWait provides synchronization for async memory operations,
    so we can skip placing a barrier if one already exists between the
    hazard pair.

    The search follows the nodes' ``.next`` links and compares
    ``_topo_location`` values:

    - Forward hazard (producer located strictly before consumer): nodes
      after the producer and before the consumer are inspected; neither
      endpoint is inspected itself.
    - Otherwise (treated as a cross-iteration hazard, which also covers
      equal locations): nodes after the producer and before ``graph_end``
      are inspected, then nodes from ``graph_start`` (inclusive) up to,
      but not including, the consumer.

    Args:
        region: The synchronization region containing the producer and
            consumer, plus ``graph_start`` / ``graph_end`` for the
            wrap-around case.
    Returns:
        True if a MemoryCounterWait exists between producer and consumer,
        False otherwise.
    """

    def wait_in_span(first, stop) -> bool:
        # Walk the `.next` chain from `first` (inclusive) and stop at the
        # end of the chain or at the first node whose location reaches the
        # location of `stop` (exclusive).
        node = first
        while node is not None and node._topo_location < stop._topo_location:
            if isinstance(get_custom(node), MemoryCounterWait):
                return True
            node = node.next
        return False

    if region.producer._topo_location < region.consumer._topo_location:
        # Forward hazard: producer precedes consumer in the same iteration.
        return wait_in_span(region.producer.next, region.consumer)

    # Cross-iteration hazard: the window wraps around the loop body, so scan
    # the tail (producer -> graph end) first and then the head
    # (graph start -> consumer). `or` short-circuits, matching the original
    # early return from the first scan.
    return wait_in_span(region.producer.next, region.graph_end) or wait_in_span(
        region.graph_start, region.consumer
    )
485+
448486 # 1) sort by (consumer, producer)
449487 regions = sorted (
450488 sync_regions ,
@@ -459,6 +497,10 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
459497 continue
460498 start , end = get_location (region )
461499
500+ # # Skip if MemoryCounterWait already provides synchronization
501+ # if has_memory_counter_wait(region):
502+ # continue
503+
462504 # A hazard window is covered if placement is at (start, end]
463505 # We append to result if this window is not covered.
464506 if not (last_pos > start and last_pos <= end ):
@@ -475,6 +517,10 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
475517 if not region .cross_iter :
476518 continue
477519
520+ # # Skip if MemoryCounterWait already provides synchronization
521+ # if has_memory_counter_wait(region):
522+ # continue
523+
478524 start , end = get_location (region )
479525 graph_start , graph_end = (
480526 region .graph_start ._topo_location ,
0 commit comments