Skip to content

Commit 145df87

Browse files
authored
find_permutes -> fix_assign_hazard [pr] (tinygrad#14354)
some noop tweaks and comment updates
1 parent e152f1b commit 145df87

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

tinygrad/schedule/rangeify.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,12 +27,13 @@
2727
# *****************
2828
# 0. do some cleanup rewrites, mostly copied from the old stuff
2929

30-
def find_permutes(a:UOp, b:UOp, assign:UOp):
31-
if not (permutes:=[s for s in b.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS)
32-
if s.op in GroupOp.Movement and s.op not in {Ops.RESHAPE, Ops.EXPAND, Ops.PAD, Ops.SHRINK}]): return
33-
target = a.base
34-
for p in permutes:
35-
if any(s is target for s in p.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})): return assign.replace(src=(a, b.contiguous()))
30+
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
31+
# PERMUTE and FLIP reorder indices, causing read/write races when src and dest are the same buffer
32+
unsafe = {Ops.PERMUTE, Ops.FLIP}
33+
if not (hazards:=[s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]): return
34+
for h in hazards:
35+
if any(s is dest.base for s in h.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER})):
36+
return assign.replace(src=(dest, src.contiguous()))
3637

3738
def split_reduceop(reduce:UOp, x:UOp):
3839
if prod(reduce.shape) == 0: return None
@@ -116,8 +117,8 @@ def resolve_custom_kernel(ck:UOp) -> UOp:
116117
lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \
117118
not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None),
118119

119-
# realize before assign if input permutes the target buffer
120-
(UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes),
120+
# make source contiguous if it has hazardous movement ops on the dest buffer
121+
(UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard),
121122
])
122123

123124
# *****************

0 commit comments

Comments
 (0)