Skip to content

Commit da500db

Browse files
authored
simplify late_buffer_view [pr] (tinygrad#14478)
check the only allowed Ops in the chain, and offset cannot be negative
1 parent b4f9630 commit da500db

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

tinygrad/schedule/rangeify.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,15 @@ def late_buffer_view(t:UOp, b:UOp):
275275
size = prod(shape)
276276

277277
# walk up for the INDEX
278+
# NOTE: even though we allow RESHAPE and SHRINK, they can combine to form non-contiguous access patterns (e.g. t[::2])
278279
x = t
279-
while not any(u.op is Ops.INDEX for u in x.src):
280-
assert x.op not in GroupOp.Elementwise, "can't buffer view elementwise"
280+
while x.op is not Ops.INDEX:
281+
assert x.op in {Ops.BITCAST, Ops.CONTIGUOUS, Ops.SHRINK, Ops.RESHAPE}, f"unexpected op {x.op} in buffer view walk"
281282
x = x.src[0]
282-
x = next(u for u in x.src if u.op is Ops.INDEX)
283283

284284
if len(shape) == 0: offset = x.src[1].arg
285-
else: offset = max(sum(idx.vmin for idx in x.src[1:]), 0)
285+
else: offset = sum(idx.vmin for idx in x.src[1:])
286+
if offset < 0: raise RuntimeError(f"negative offset {offset} in buffer view")
286287

287288
return b.replace(src=(UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (size, offset), tag=t.tag), b.src[1]))
288289

0 commit comments

Comments
 (0)