
Commit 2e72625

Revert "decompose dtypes.long to ints where unsupported (tinygrad#14261)" (tinygrad#14362)

This reverts tinygrad#14261, which decomposed dtypes.long into pairs of 32-bit ints on backends without native 64-bit integer support. With the revert, the "long as 2 ints" lowering is removed and the tests go back to skipping 64-bit integer cases where the dtype is unsupported.

1 parent f866b2a · commit 2e72625

6 files changed: +15 −94 lines

test/test_dtype.py

Lines changed: 2 additions & 8 deletions

```diff
@@ -18,7 +18,7 @@
 settings.load_profile("my_profile")

 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
-  if not is_dtype_supported(dtype) and dtype not in (dtypes.long, dtypes.ulong): return []
+  if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
   return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]

@@ -333,14 +333,8 @@ def test_uint16_to_int8_overflow(self):
 class TestInt32DType(TestDType): DTYPE = dtypes.int32
 class TestUint32DType(TestDType): DTYPE = dtypes.uint32

-class TestInt64DType(TestDType):
-  DTYPE = dtypes.int64
-  @classmethod
-  def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
-
+class TestInt64DType(TestDType): DTYPE = dtypes.int64
 class TestUint64DType(TestDType):
-  @classmethod
-  def setUpClass(cls): cls.DATA = rand_for_dtype(cls.DTYPE, 10)
   DTYPE = dtypes.uint64
   def test_uint64_load(self):
     assert Tensor(2**64 - 1, dtype=dtypes.uint64).numpy() == 2**64 - 1
```
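The restored guard means cast coverage is simply pruned on backends that lack 64-bit integers, instead of relying on the reverted int-pair emulation. A minimal sketch of querying that capability directly, using only names that appear in this diff:

```python
from tinygrad import Device, dtypes
from tinygrad.device import is_dtype_supported

# After the revert, long/ulong drop out of cast coverage entirely on
# backends where they are unsupported; nothing emulates them anymore.
for dt in (dtypes.long, dtypes.ulong):
  if not is_dtype_supported(dt, Device.DEFAULT):
    print(f"{dt} is unsupported on {Device.DEFAULT}: cast tests skip it")
```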

test/test_dtype_alu.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -165,6 +165,7 @@ def test_uint16(self, a, b, op): universal_test(a, b, dtypes.uint16, op)
   @given(ht.uint32, ht.uint32, strat.sampled_from(integer_binary_operations))
   def test_uint32(self, a, b, op): universal_test(a, b, dtypes.uint32, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
   @given(ht.uint64, ht.uint64, strat.sampled_from(integer_binary_operations))
   def test_uint64(self, a, b, op): universal_test(a, b, dtypes.uint64, op)

@@ -177,6 +178,7 @@ def test_int16(self, a, b, op): universal_test(a, b, dtypes.int16, op)
   @given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
   def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
   @given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
   def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)

@@ -191,6 +193,7 @@ def test_uint16_unary(self, a, op): universal_test_unary(a, dtypes.uint16, op)
   @given(ht.uint32, strat.sampled_from(integer_unary_operations))
   def test_uint32_unary(self, a, op): universal_test_unary(a, dtypes.uint32, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.uint64), f"no uint64 on {Device.DEFAULT}")
   @given(ht.uint64, strat.sampled_from(integer_unary_operations))
   def test_uint64_unary(self, a, op): universal_test_unary(a, dtypes.uint64, op)

@@ -203,6 +206,7 @@ def test_int16_unary(self, a, op): universal_test_unary(a, dtypes.int16, op)
   @given(ht.int32, strat.sampled_from(integer_unary_operations))
   def test_int32_unary(self, a, op): universal_test_unary(a, dtypes.int32, op)

+  @unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
   @given(ht.int64, strat.sampled_from(integer_unary_operations))
   def test_int64_unary(self, a, op): universal_test_unary(a, dtypes.int64, op)
```
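All four added decorators follow the same shape. The sketch below (the class name and round-trip check are hypothetical, not from the diff) shows why the skip sits outside `@given`: unittest evaluates the skip first, so hypothesis never generates examples for an unsupported dtype.

```python
import unittest
from hypothesis import given, strategies as strat
from tinygrad import Device, Tensor, dtypes
from tinygrad.device import is_dtype_supported

class TestInt64Gate(unittest.TestCase):
  # the skip decorator wraps the hypothesis-wrapped test, so the whole
  # property-based test is skipped before any examples are drawn
  @unittest.skipUnless(is_dtype_supported(dtypes.int64), f"no int64 on {Device.DEFAULT}")
  @given(strat.integers(-2**63, 2**63 - 1))
  def test_int64_roundtrip(self, a):
    assert Tensor(a, dtype=dtypes.int64).item() == a
```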

test/test_edgecases.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -26,7 +26,7 @@
 import numpy as np
 import torch
 from tinygrad import Tensor, dtypes, nn
-from tinygrad.device import Device
+from tinygrad.device import Device, is_dtype_supported
 from tinygrad.helpers import getenv
 from tinygrad.renderer.nir import NIRRenderer

@@ -207,7 +207,8 @@ class TestUOpValidationIssue(unittest.TestCase):
   # these fail with UOp verification error.
   # we want more of these with diverse errors!

-  @unittest.skipIf(MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer), "hangs gpuocelot, NIR cannot render")
+  @unittest.skipIf((not is_dtype_supported(dtypes.long)) or MOCKGPU or isinstance(Device[Device.DEFAULT].renderer, NIRRenderer),
+                   "hangs gpuocelot, NIR cannot render")
   def test_tensor_index_overflow(self):
     val = Tensor([1])
     big = val.expand(2**31 + 3)
```
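The extra `is_dtype_supported(dtypes.long)` condition is implied by the test body rather than stated in the diff: `val.expand(2**31 + 3)` produces more elements than a signed 32-bit index can address, so the kernel needs 64-bit index arithmetic. A one-line check of that arithmetic:

```python
# 2**31 + 3 elements cannot be indexed with a signed 32-bit integer,
# so this test inherently requires dtypes.long on the device.
INT32_MAX = 2**31 - 1
assert 2**31 + 3 > INT32_MAX
```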

test/unit/test_indexing.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -339,7 +339,7 @@ def test_index_put_accumulate_duplicate_indices(self):
     numpy_testing_assert_equal_helper(output, input_list)
     '''

-  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't support long indexing: #13624")
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"long dtype not supported on {Device.DEFAULT}")
   def test_index_ind_dtype(self):
     x = Tensor.randn(4, 4)
     # ind_long = torch.randint(4, (4,), dtype=torch.long)
```

tinygrad/codegen/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -95,7 +95,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -

   # decompositions
   supported_ops = tuple(ren.code_for_op.keys())
-  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, ren.device, TRANSCENDENTAL>=2)
+  pm_decomp = symbolic_simple+get_late_rewrite_patterns(supported_ops, TRANSCENDENTAL>=2)
   sink = graph_rewrite(sink, pm_decomp, ctx=ren.device, name="decompositions")

   # final rules for the renderer (without sym)
```
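With `device` removed from the signature of the `functools.cache`-keyed `get_late_rewrite_patterns`, patterns that still need the device receive it through the rewrite context instead: `graph_rewrite` is still called with `ctx=ren.device`. A small sketch of that mechanism (the handler is illustrative, not from the codebase; the pattern mirrors the `fast_idiv` entry in this diff):

```python
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UPat, PatternMatcher

def log_idiv(ctx, x, d):
  # a handler whose first parameter is named ctx receives the value passed
  # as graph_rewrite(..., ctx=...) -- here, the device string
  print(f"idiv pattern matched while rendering for {ctx}")
  return None  # None means "no rewrite": the UOp is left unchanged

pm = PatternMatcher([(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), log_idiv)])
```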

tinygrad/uop/decompositions.py

Lines changed: 4 additions & 82 deletions

```diff
@@ -2,8 +2,7 @@
 import math, functools
 from tinygrad.dtype import dtypes, DType, promo_lattice
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import flatten, polyN, DISABLE_FAST_IDIV
-from tinygrad.uop import GroupOp
+from tinygrad.helpers import polyN, DISABLE_FAST_IDIV
 from tinygrad.uop.ops import UOp, UPat, Ops, PatternMatcher

 TRANSCENDENTAL_DTYPES = (dtypes.float16, dtypes.float32, dtypes.float64)
@@ -315,70 +314,11 @@ def threefry2x32(x: UOp, key: UOp):

   return xr[1].cast(dtypes.uint64) * 2**32 | xr[0].cast(dtypes.uint64)

-# ***** long as 2 ints *****
-
-l2i_dt = {dtypes.long: dtypes.int, dtypes.ulong: dtypes.uint}
-def unpack32(v): return v.bitcast(dtypes.uint) & 0xFFFF, v.bitcast(dtypes.uint) >> 16
-def l2i_idx(idx,off): return idx.replace(src=(idx.src[0], idx.src[1]*2+off))
-
-# 4.3.1 is the relevant section in TAOCP
-def l2i(op: Ops, dt: DType, *uops:UOp):
-  zero = UOp.const(dt, 0)
-  if len(uops) == 2: a0, a1 = uops
-  elif len(uops) == 4: a0, a1, b0, b1 = uops
-  match op:
-    case Ops.NEG: return l2i(Ops.SUB, dt, zero, zero, *uops)
-    case Ops.CAST if dt in (dtypes.long, dtypes.ulong) and uops[0].dtype not in dtypes.floats:
-      return uops[0].cast(l2i_dt[dt]), (uops[0] < 0).where(UOp.const(l2i_dt[dt], -1), UOp.const(l2i_dt[dt], 0))
-    case Ops.CAST if dt in (dtypes.long, dtypes.ulong):
-      return (lo:=uops[0].cast(l2i_dt[dt])), (uops[0] / 2**32).cast(l2i_dt[dt]) - ((uops[0] < 0) & lo.ne(0)).cast(l2i_dt[dt])
-    case Ops.CAST if dt in dtypes.floats:
-      small = (a1.eq(0) & (a0 >= 0)) | (a1.eq(-1) & (a0 < 0))
-      return small.where(a0.cast(dt), ((a1.cast(dtypes.float32) * (2**32)) + a0.bitcast(dtypes.uint).cast(dtypes.float32)).cast(dt))
-    case Ops.CAST: return a0.bitcast(dtypes.uint).cast(dt)
-    case Ops.BITCAST: return a0.bitcast(dt), a1.bitcast(dt)
-    case Ops.SHL:
-      lo, hi = a0 << (b0_mod:=b0 & 31), (a1 << b0_mod) | ((a0 >> 1) >> (31 - b0_mod))
-      return (b0 >= 32).where(zero, lo), (b0 >= 32).where(lo, hi)
-    case Ops.SHR:
-      lo, hi = (a0 >> (b0_mod:=b0 & 31)) | ((a1 << 1) << (31 - b0_mod)), a1 >> b0_mod
-      return (b0 >= 32).where(hi, lo), (b0 >= 32).where(zero, hi)
-    case Ops.ADD: return (low:=a0+b0), (a1 + b1).replace(dtype=dt) + (low.bitcast(dtypes.uint) < a0.bitcast(dtypes.uint)).cast(dt)
-    case Ops.SUB: return a0 - b0, a1 - b1 - (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)).cast(dt)
-    case Ops.MUL:
-      (a00, a01), (b00, b01) = unpack32(a0), unpack32(b0)
-      mid = l2i(Ops.ADD, dt, ((a00*b01)<<16).bitcast(dt), ((a00*b01)>>16).bitcast(dt), ((a01*b00)<<16).bitcast(dt), ((a01*b00)>>16).bitcast(dt))
-      return l2i(Ops.ADD, dt, *mid, (a00*b00).bitcast(dt), (a01*b01).bitcast(dt) + a0*b1 + a1*b0)
-    case Ops.IDIV | Ops.MOD:
-      # TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
-      if dt == dtypes.int:
-        a0, a1 = (a_neg:=a1 < zero).where((n:=l2i(Ops.NEG, dt, a0, a1))[0], a0).bitcast(dtypes.uint), a_neg.where(n[1], a1).bitcast(dtypes.uint)
-        b0, b1 = (b_neg:=b1 < zero).where((n:=l2i(Ops.NEG, dt, b0, b1))[0], b0).bitcast(dtypes.uint), b_neg.where(n[1], b1).bitcast(dtypes.uint)
-      q, r = (z:=UOp.const(dtypes.uint, 0), z), (z, z)
-      for i in range(63, -1, -1):
-        r = l2i(Ops.SHL, dtypes.uint, *r, UOp.const(dtypes.uint, 1), z)
-        r = (r[0] | l2i(Ops.SHR, dtypes.uint, a0, a1, UOp.const(dtypes.uint, i), z)[0] & 1), r[1]
-        cond = l2i(Ops.CMPLT, dtypes.uint, *r, b0, b1).logical_not()
-        diff = l2i(Ops.SUB, dtypes.uint, *r, b0, b1)
-        q = ((q[0] | cond.cast(dtypes.uint) << (i % 32), q[1]) if i < 32 else (q[0], q[1] | cond.cast(dtypes.uint) << (i % 32)))
-        r = l2i(Ops.WHERE, dtypes.uint, cond, *diff, *r)
-      if dt == dtypes.int:
-        nq, nr = l2i(Ops.NEG, dt, q0:=q[0].bitcast(dt), q1:=q[1].bitcast(dt)), l2i(Ops.NEG, dt, r0:=r[0].bitcast(dt), r1:=r[1].bitcast(dt))
-        return (a_neg.where(nr[0], r0), a_neg.where(nr[1], r1)) if op == Ops.MOD else ((a_neg^b_neg).where(nq[0], q0), (a_neg^b_neg).where(nq[1], q1))
-      return (r[0].bitcast(dt), r[1].bitcast(dt)) if op == Ops.MOD else (q[0].bitcast(dt), q[1].bitcast(dt))
-    case Ops.CMPLT: return (a1 < b1) | ((a1.eq(b1)) & (a0.bitcast(dtypes.uint) < b0.bitcast(dtypes.uint)))
-    case Ops.CMPEQ: return a0.eq(b0) & a1.eq(b1)
-    case Ops.CMPNE: return a0.ne(b0) | a1.ne(b1)
-    case Ops.XOR | Ops.OR | Ops.AND: return UOp(op, dt, src=(a0, b0)), UOp(op, dt, src=(a1, b1))
-    case Ops.WHERE: return uops[0].where(uops[1], uops[3]), uops[0].where(uops[2], uops[4])
-    case Ops.MAX: return l2i(Ops.WHERE, dt, l2i(Ops.CMPLT, dt, *uops), b0, b1, a0, a1)
-    case _: raise NotImplementedError(f"long decomposition of {op} unsupported")
-
 # ***** decomposition patterns *****

 powers_of_two = {2**i:i for i in range(64)}
 @functools.cache
-def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental):
+def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
   pat: list[tuple[UPat, Callable]] = []
   for op,f in ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)):
     if op not in ops or force_transcendental:
@@ -406,8 +346,8 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
   pat += [(UPat.var("x", dtypes.ints)//UPat.cvar("d", vec=False), lambda ctx, x, d: fast_idiv(ctx, x, d.arg))]
   pat += [(UPat.var("x", dtypes.ints)%UPat.var("d"), lambda x, d: x-d*(x//d))]
   if Ops.NEG in ops:
-    pat += [(UPat.var('x')*-1, lambda ctx,x: x.alu(Ops.NEG))]
-    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda ctx,x,y: x.alu(Ops.SUB, y))]
+    pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
+    if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
   if Ops.CMPLT in ops:
     # These are late rewrites because simplex expects equalities to be a certain format
     pat += [
@@ -424,22 +364,4 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
   if Ops.FDIV in ops:
     pat += [(UPat.var("x").reciprocal(), lambda x: x.const_like(1).alu(Ops.FDIV, x))]
     pat += [(UPat.var("a", dtypes.floats) * UPat.const(dtypes.floats, 1).alu(Ops.FDIV, UPat.var("b")), lambda a,b: a.alu(Ops.FDIV, b))]
-  if not is_dtype_supported(dtypes.long, device):
-    pat += [(UPat((*GroupOp.Defines, Ops.INDEX), name="x"), lambda x:
-      x.replace(dtype=l2i_dt[x.dtype.base].ptr(x.dtype.size * 2)) if x.dtype.base in l2i_dt else None)]
-    pat += [(UPat(Ops.STORE, src=(UPat.var('idx'), UPat.var('val', tuple(l2i_dt.keys()))), name='st'), lambda st,idx,val:
-      st.replace(src=(l2i_idx(idx, 0), val.rtag(0))).group(st.replace(src=(l2i_idx(idx, 1), val.rtag(1)))) if val.tag is None else None)]
-    pat += [(UPat(GroupOp.Comparison, src=(UPat.var('a', tuple(l2i_dt.keys())), UPat.var('b', tuple(l2i_dt.keys()))), name="x"), lambda a,b,x:
-      l2i(x.op, dt:=l2i_dt[a.dtype], a.rtag(0).cast(dt), a.rtag(1).cast(dt), b.rtag(0).cast(dt), b.rtag(1).cast(dt)))]
-    pat += [(UPat(Ops.CAST, tuple(l2i_dt.keys()), src=(UPat.var('a'),), name="x"), lambda a,x:
-      l2i(x.op, x.dtype, a)[x.tag] if x.tag is not None else None)]
-    pat += [(UPat(Ops.CAST, src=(UPat.var('a', tuple(l2i_dt.keys())),), name="x"), lambda a,x:
-      l2i(x.op, x.dtype, a.rtag(0).cast(dt:=l2i_dt[a.dtype]), a.rtag(1).cast(dt)))]
-    pat += [(UPat((*(GroupOp.ALU - GroupOp.Comparison), Ops.BITCAST), tuple(l2i_dt.keys()), name="x"), lambda x:
-      None if x.tag is None else l2i(x.op, l2i_dt[x.dtype], *flatten((a.rtag(0).cast(dt:=l2i_dt[x.src[-1].dtype]), a.rtag(1).cast(dt))
-        if a.dtype in l2i_dt else (a,) for a in x.src))[x.tag])]
-    pat += [(UPat(Ops.LOAD, tuple(l2i_dt.keys()), src=(UPat.var('idx'),), name='x'), lambda x,idx:
-      None if x.tag is None else x.replace(dtype=l2i_dt[x.dtype], src=(l2i_idx(idx, x.tag),)))]
-    pat += [(UPat(Ops.CONST, tuple(l2i_dt.keys()), name='x'), lambda x:
-      None if x.tag is None else UOp.const(l2i_dt[x.dtype], (x.arg >> 32) if x.tag == 1 else (x.arg & 0xFFFFFFFF)))]
   return PatternMatcher(pat)
```
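For readers wondering what the deleted `l2i` helper did: it lowered each 64-bit integer UOp into a (low, high) pair of 32-bit UOps, with carry and borrow handling per TAOCP §4.3.1. A plain-Python sketch of just the `Ops.ADD` case (illustrative only; the real code operated on UOps, not Python ints):

```python
MASK32 = 0xFFFFFFFF

def add64(a: tuple[int, int], b: tuple[int, int]) -> tuple[int, int]:
  """Add two 64-bit values represented as (lo, hi) pairs of 32-bit words."""
  (a0, a1), (b0, b1) = a, b
  lo = (a0 + b0) & MASK32
  # mirrors the deleted UOp code: low.bitcast(uint) < a0.bitcast(uint)
  # detects unsigned wraparound, i.e. a carry out of the low word
  carry = 1 if lo < a0 else 0
  hi = (a1 + b1 + carry) & MASK32
  return lo, hi

# 0xFFFFFFFF + 1 carries into the high word
assert add64((0xFFFFFFFF, 0), (1, 0)) == (0, 1)
```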
