Skip to content

Commit a7c6841

Browse files
committed
Check for memory counter wait during barrier placement
Signed-off-by: nithinsubbiah <[email protected]>
1 parent 6970696 commit a7c6841

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

wave_lang/kernel/wave/gather_to_shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def emit_global_to_lds(
328328
nd_index = config.get_offset(i)
329329
logger.info(f"nd_index={nd_index}")
330330
write_index = {}
331-
for bound_expr, idx in zip(read.indexing_dims, new_nd_index):
331+
for bound_expr, idx in zip(read.indexing_dims, nd_index):
332332
last = bound_expr == read.indexing_dims[-1]
333333
dim = infer_dim(bound_expr)
334334
size = elements_per_thread if last else 1

wave_lang/kernel/wave/utils/barriers_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
CustomOp,
2222
GatherToLDS,
2323
TensorLoadToLDS,
24+
MemoryCounterWait,
2425
Read,
2526
Write,
2627
get_custom,
@@ -424,6 +425,7 @@ def minimize_placement_strategy(
424425
Efficient greedy barrier placement.
425426
- Forward hazards: O(n log n) sort + O(n) sweep with a single "last_pos".
426427
- Cross-iter hazards: two O(log m) range checks via binary search over an always-sorted list of chosen placement positions.
428+
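- Intended to skip hazards that already have MemoryCounterWait synchronization (see has_memory_counter_wait); the skip checks are currently commented out, so no hazard is skipped yet.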
427429
"""
428430
if not sync_regions:
429431
return []
@@ -445,6 +447,42 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
445447
return False
446448
return any(lo <= p <= hi for p in ranges)
447449

450+
451+
def has_memory_counter_wait(region: SyncRegion):
    """
    Report whether a MemoryCounterWait node lies between a hazard's producer and consumer.

    Per the original rationale, MemoryCounterWait provides synchronization for
    async memory operations, so a barrier is unnecessary when one already sits
    between the hazard pair.

    Two cases are distinguished by comparing `_topo_location` values:

    - Producer located before consumer (forward hazard): the nodes following
      the producer are scanned up to, but not including, the consumer.
    - Otherwise (cross-iteration hazard): the nodes following the producer are
      scanned up to, but not including, `region.graph_end`; if nothing is
      found, the scan restarts at `region.graph_start` and runs up to, but
      not including, the consumer.

    Args:
        region: The synchronization region holding the producer, the consumer,
            and the `graph_start` / `graph_end` nodes.
    Returns:
        True if a MemoryCounterWait node is found in the scanned range(s),
        False otherwise.
    """

    def wait_found(node, stop_node):
        # Follow the `next` links from `node`; stop when the chain ends or when
        # a node's topological location reaches that of `stop_node`.
        while node is not None and node._topo_location < stop_node._topo_location:
            if isinstance(get_custom(node), MemoryCounterWait):
                return True
            node = node.next
        return False

    producer, consumer = region.producer, region.consumer

    if producer._topo_location < consumer._topo_location:
        # Forward hazard: producer precedes consumer within the same iteration.
        return wait_found(producer.next, consumer)

    # Cross-iteration hazard: look from the producer towards the end of the
    # graph first, then wrap around from the start of the graph to the consumer.
    return wait_found(producer.next, region.graph_end) or wait_found(
        region.graph_start, consumer
    )
485+
448486
# 1) sort by (consumer, producer)
449487
regions = sorted(
450488
sync_regions,
@@ -459,6 +497,10 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
459497
continue
460498
start, end = get_location(region)
461499

500+
# # Skip if MemoryCounterWait already provides synchronization
501+
# if has_memory_counter_wait(region):
502+
# continue
503+
462504
# A hazard window is covered if placement is at (start, end]
463505
# We append to result if this window is not covered.
464506
if not (last_pos > start and last_pos <= end):
@@ -475,6 +517,10 @@ def in_range(ranges: List[int], lo: int, hi: int) -> bool:
475517
if not region.cross_iter:
476518
continue
477519

520+
# # Skip if MemoryCounterWait already provides synchronization
521+
# if has_memory_counter_wait(region):
522+
# continue
523+
478524
start, end = get_location(region)
479525
graph_start, graph_end = (
480526
region.graph_start._topo_location,

0 commit comments

Comments
 (0)