|
27 | 27 | # ***************** |
28 | 28 | # 0. do some cleanup rewrites, mostly copied from the old stuff |
29 | 29 |
|
def find_permutes(a:UOp, b:UOp, assign:UOp):
  """Rewrite callback for ASSIGN: realize `b` before the assign if it permutes the target buffer.

  Collects every movement op reachable from `b` (other than RESHAPE, EXPAND, PAD
  and SHRINK), then checks whether any of them can reach `a.base`. On the first
  hit the assign is rebuilt with `b.contiguous()` as its source.

  Returns the replaced assign, or None when nothing needs to change.
  """
  excluded = {Ops.RESHAPE, Ops.EXPAND, Ops.PAD, Ops.SHRINK}
  candidates = []
  for node in b.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS):
    if node.op in GroupOp.Movement and node.op not in excluded:
      candidates.append(node)
  if not candidates:
    return None
  target = a.base
  for cand in candidates:
    # BUFFER is removed from the gate's stop set here, unlike the first walk above
    for node in cand.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS-{Ops.BUFFER}):
      if node is target:
        return assign.replace(src=(a, b.contiguous()))
  return None
def fix_assign_hazard(dest:UOp, src:UOp, assign:UOp):
  """Rewrite callback for ASSIGN: make `src` contiguous when it reads `dest` through PERMUTE/FLIP.

  Args:
    dest: the assign target (first source of the ASSIGN).
    src: the value being assigned (second source of the ASSIGN).
    assign: the ASSIGN node itself.

  Returns:
    `assign` rebuilt with `src.contiguous()` as its source when a PERMUTE or FLIP
    under `src` can reach `dest.base`; otherwise None (no rewrite).
  """
  # PERMUTE and FLIP reorder indices, causing read/write races when src and dest are the same buffer
  unsafe = {Ops.PERMUTE, Ops.FLIP}
  hazards = [s for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS) if s.op in unsafe]
  if not hazards:
    return None
  # both values are loop-invariant: resolve them once instead of once per visited node
  target = dest.base
  stop_ops = ALWAYS_CONTIGUOUS-{Ops.BUFFER}
  for h in hazards:
    if any(s is target for s in h.toposort(gate=lambda s:s.op not in stop_ops)):
      return assign.replace(src=(dest, src.contiguous()))
  return None
36 | 37 |
|
37 | 38 | def split_reduceop(reduce:UOp, x:UOp): |
38 | 39 | if prod(reduce.shape) == 0: return None |
@@ -116,8 +117,8 @@ def resolve_custom_kernel(ck:UOp) -> UOp: |
116 | 117 | lambda x,target,assign: x.f(Ops.CONTIGUOUS, tag=assign.tag) if ((t:=target.base).op is not Ops.BUFFER and \ |
117 | 118 | not (t.op is Ops.MSTACK and all(s.op is Ops.BUFFER for s in t.src))) else None), |
118 | 119 |
|
119 | | - # realize before assign if input permutes the target buffer |
120 | | - (UPat(Ops.ASSIGN, src=(UPat.var("a"), UPat.var("b")), name="assign"), find_permutes), |
| 120 | + # make source contiguous if it has hazardous movement ops on the dest buffer |
| 121 | + (UPat(Ops.ASSIGN, src=(UPat.var("dest"), UPat.var("src")), name="assign"), fix_assign_hazard), |
121 | 122 | ]) |
122 | 123 |
|
123 | 124 | # ***************** |
|
0 commit comments