22import math , functools
33from tinygrad .dtype import dtypes , DType , promo_lattice
44from tinygrad .device import is_dtype_supported
5- from tinygrad .helpers import flatten , polyN , DISABLE_FAST_IDIV
6- from tinygrad .uop import GroupOp
5+ from tinygrad .helpers import polyN , DISABLE_FAST_IDIV
76from tinygrad .uop .ops import UOp , UPat , Ops , PatternMatcher
87
98TRANSCENDENTAL_DTYPES = (dtypes .float16 , dtypes .float32 , dtypes .float64 )
@@ -315,70 +314,11 @@ def threefry2x32(x: UOp, key: UOp):
315314
316315 return xr [1 ].cast (dtypes .uint64 ) * 2 ** 32 | xr [0 ].cast (dtypes .uint64 )
317316
318- # ***** long as 2 ints *****
319-
320- l2i_dt = {dtypes .long : dtypes .int , dtypes .ulong : dtypes .uint }
321- def unpack32 (v ): return v .bitcast (dtypes .uint ) & 0xFFFF , v .bitcast (dtypes .uint ) >> 16
322- def l2i_idx (idx ,off ): return idx .replace (src = (idx .src [0 ], idx .src [1 ]* 2 + off ))
323-
324- # 4.3.1 is the relevant section in TAOCP
325- def l2i (op : Ops , dt : DType , * uops :UOp ):
326- zero = UOp .const (dt , 0 )
327- if len (uops ) == 2 : a0 , a1 = uops
328- elif len (uops ) == 4 : a0 , a1 , b0 , b1 = uops
329- match op :
330- case Ops .NEG : return l2i (Ops .SUB , dt , zero , zero , * uops )
331- case Ops .CAST if dt in (dtypes .long , dtypes .ulong ) and uops [0 ].dtype not in dtypes .floats :
332- return uops [0 ].cast (l2i_dt [dt ]), (uops [0 ] < 0 ).where (UOp .const (l2i_dt [dt ], - 1 ), UOp .const (l2i_dt [dt ], 0 ))
333- case Ops .CAST if dt in (dtypes .long , dtypes .ulong ):
334- return (lo := uops [0 ].cast (l2i_dt [dt ])), (uops [0 ] / 2 ** 32 ).cast (l2i_dt [dt ]) - ((uops [0 ] < 0 ) & lo .ne (0 )).cast (l2i_dt [dt ])
335- case Ops .CAST if dt in dtypes .floats :
336- small = (a1 .eq (0 ) & (a0 >= 0 )) | (a1 .eq (- 1 ) & (a0 < 0 ))
337- return small .where (a0 .cast (dt ), ((a1 .cast (dtypes .float32 ) * (2 ** 32 )) + a0 .bitcast (dtypes .uint ).cast (dtypes .float32 )).cast (dt ))
338- case Ops .CAST : return a0 .bitcast (dtypes .uint ).cast (dt )
339- case Ops .BITCAST : return a0 .bitcast (dt ), a1 .bitcast (dt )
340- case Ops .SHL :
341- lo , hi = a0 << (b0_mod := b0 & 31 ), (a1 << b0_mod ) | ((a0 >> 1 ) >> (31 - b0_mod ))
342- return (b0 >= 32 ).where (zero , lo ), (b0 >= 32 ).where (lo , hi )
343- case Ops .SHR :
344- lo , hi = (a0 >> (b0_mod := b0 & 31 )) | ((a1 << 1 ) << (31 - b0_mod )), a1 >> b0_mod
345- return (b0 >= 32 ).where (hi , lo ), (b0 >= 32 ).where (zero , hi )
346- case Ops .ADD : return (low := a0 + b0 ), (a1 + b1 ).replace (dtype = dt ) + (low .bitcast (dtypes .uint ) < a0 .bitcast (dtypes .uint )).cast (dt )
347- case Ops .SUB : return a0 - b0 , a1 - b1 - (a0 .bitcast (dtypes .uint ) < b0 .bitcast (dtypes .uint )).cast (dt )
348- case Ops .MUL :
349- (a00 , a01 ), (b00 , b01 ) = unpack32 (a0 ), unpack32 (b0 )
350- mid = l2i (Ops .ADD , dt , ((a00 * b01 )<< 16 ).bitcast (dt ), ((a00 * b01 )>> 16 ).bitcast (dt ), ((a01 * b00 )<< 16 ).bitcast (dt ), ((a01 * b00 )>> 16 ).bitcast (dt ))
351- return l2i (Ops .ADD , dt , * mid , (a00 * b00 ).bitcast (dt ), (a01 * b01 ).bitcast (dt ) + a0 * b1 + a1 * b0 )
352- case Ops .IDIV | Ops .MOD :
353- # TAOCP Algorithm 4.3.1D could be faster here, but must be parameterized over the width of b
354- if dt == dtypes .int :
355- a0 , a1 = (a_neg := a1 < zero ).where ((n := l2i (Ops .NEG , dt , a0 , a1 ))[0 ], a0 ).bitcast (dtypes .uint ), a_neg .where (n [1 ], a1 ).bitcast (dtypes .uint )
356- b0 , b1 = (b_neg := b1 < zero ).where ((n := l2i (Ops .NEG , dt , b0 , b1 ))[0 ], b0 ).bitcast (dtypes .uint ), b_neg .where (n [1 ], b1 ).bitcast (dtypes .uint )
357- q , r = (z := UOp .const (dtypes .uint , 0 ), z ), (z , z )
358- for i in range (63 , - 1 , - 1 ):
359- r = l2i (Ops .SHL , dtypes .uint , * r , UOp .const (dtypes .uint , 1 ), z )
360- r = (r [0 ] | l2i (Ops .SHR , dtypes .uint , a0 , a1 , UOp .const (dtypes .uint , i ), z )[0 ] & 1 ), r [1 ]
361- cond = l2i (Ops .CMPLT , dtypes .uint , * r , b0 , b1 ).logical_not ()
362- diff = l2i (Ops .SUB , dtypes .uint , * r , b0 , b1 )
363- q = ((q [0 ] | cond .cast (dtypes .uint ) << (i % 32 ), q [1 ]) if i < 32 else (q [0 ], q [1 ] | cond .cast (dtypes .uint ) << (i % 32 )))
364- r = l2i (Ops .WHERE , dtypes .uint , cond , * diff , * r )
365- if dt == dtypes .int :
366- nq , nr = l2i (Ops .NEG , dt , q0 := q [0 ].bitcast (dt ), q1 := q [1 ].bitcast (dt )), l2i (Ops .NEG , dt , r0 := r [0 ].bitcast (dt ), r1 := r [1 ].bitcast (dt ))
367- return (a_neg .where (nr [0 ], r0 ), a_neg .where (nr [1 ], r1 )) if op == Ops .MOD else ((a_neg ^ b_neg ).where (nq [0 ], q0 ), (a_neg ^ b_neg ).where (nq [1 ], q1 ))
368- return (r [0 ].bitcast (dt ), r [1 ].bitcast (dt )) if op == Ops .MOD else (q [0 ].bitcast (dt ), q [1 ].bitcast (dt ))
369- case Ops .CMPLT : return (a1 < b1 ) | ((a1 .eq (b1 )) & (a0 .bitcast (dtypes .uint ) < b0 .bitcast (dtypes .uint )))
370- case Ops .CMPEQ : return a0 .eq (b0 ) & a1 .eq (b1 )
371- case Ops .CMPNE : return a0 .ne (b0 ) | a1 .ne (b1 )
372- case Ops .XOR | Ops .OR | Ops .AND : return UOp (op , dt , src = (a0 , b0 )), UOp (op , dt , src = (a1 , b1 ))
373- case Ops .WHERE : return uops [0 ].where (uops [1 ], uops [3 ]), uops [0 ].where (uops [2 ], uops [4 ])
374- case Ops .MAX : return l2i (Ops .WHERE , dt , l2i (Ops .CMPLT , dt , * uops ), b0 , b1 , a0 , a1 )
375- case _: raise NotImplementedError (f"long decomposition of { op } unsupported" )
376-
377317# ***** decomposition patterns *****
378318
379319powers_of_two = {2 ** i :i for i in range (64 )}
380320@functools .cache
381- def get_late_rewrite_patterns (ops :tuple [Ops , ...], device , force_transcendental ):
321+ def get_late_rewrite_patterns (ops :tuple [Ops , ...], force_transcendental ):
382322 pat : list [tuple [UPat , Callable ]] = []
383323 for op ,f in ((Ops .EXP2 , xexp2 ), (Ops .LOG2 , xlog2 ), (Ops .SIN , xsin )):
384324 if op not in ops or force_transcendental :
@@ -406,8 +346,8 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
406346 pat += [(UPat .var ("x" , dtypes .ints )// UPat .cvar ("d" , vec = False ), lambda ctx , x , d : fast_idiv (ctx , x , d .arg ))]
407347 pat += [(UPat .var ("x" , dtypes .ints )% UPat .var ("d" ), lambda x , d : x - d * (x // d ))]
408348 if Ops .NEG in ops :
409- pat += [(UPat .var ('x' )* - 1 , lambda ctx , x : x .alu (Ops .NEG ))]
410- if Ops .SUB in ops : pat += [(UPat .var ('x' )+ UPat .var ('y' ).alu (Ops .NEG ), lambda ctx , x ,y : x .alu (Ops .SUB , y ))]
349+ pat += [(UPat .var ('x' )* - 1 , lambda x : x .alu (Ops .NEG ))]
350+ if Ops .SUB in ops : pat += [(UPat .var ('x' )+ UPat .var ('y' ).alu (Ops .NEG ), lambda x ,y : x .alu (Ops .SUB , y ))]
411351 if Ops .CMPLT in ops :
412352 # These are late rewrites because simplex expects equalities to be a certain format
413353 pat += [
@@ -424,22 +364,4 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], device, force_transcendental)
424364 if Ops .FDIV in ops :
425365 pat += [(UPat .var ("x" ).reciprocal (), lambda x : x .const_like (1 ).alu (Ops .FDIV , x ))]
426366 pat += [(UPat .var ("a" , dtypes .floats ) * UPat .const (dtypes .floats , 1 ).alu (Ops .FDIV , UPat .var ("b" )), lambda a ,b : a .alu (Ops .FDIV , b ))]
427- if not is_dtype_supported (dtypes .long , device ):
428- pat += [(UPat ((* GroupOp .Defines , Ops .INDEX ), name = "x" ), lambda x :
429- x .replace (dtype = l2i_dt [x .dtype .base ].ptr (x .dtype .size * 2 )) if x .dtype .base in l2i_dt else None )]
430- pat += [(UPat (Ops .STORE , src = (UPat .var ('idx' ), UPat .var ('val' , tuple (l2i_dt .keys ()))), name = 'st' ), lambda st ,idx ,val :
431- st .replace (src = (l2i_idx (idx , 0 ), val .rtag (0 ))).group (st .replace (src = (l2i_idx (idx , 1 ), val .rtag (1 )))) if val .tag is None else None )]
432- pat += [(UPat (GroupOp .Comparison , src = (UPat .var ('a' , tuple (l2i_dt .keys ())), UPat .var ('b' , tuple (l2i_dt .keys ()))), name = "x" ), lambda a ,b ,x :
433- l2i (x .op , dt := l2i_dt [a .dtype ], a .rtag (0 ).cast (dt ), a .rtag (1 ).cast (dt ), b .rtag (0 ).cast (dt ), b .rtag (1 ).cast (dt )))]
434- pat += [(UPat (Ops .CAST , tuple (l2i_dt .keys ()), src = (UPat .var ('a' ),), name = "x" ), lambda a ,x :
435- l2i (x .op , x .dtype , a )[x .tag ] if x .tag is not None else None )]
436- pat += [(UPat (Ops .CAST , src = (UPat .var ('a' , tuple (l2i_dt .keys ())),), name = "x" ), lambda a ,x :
437- l2i (x .op , x .dtype , a .rtag (0 ).cast (dt := l2i_dt [a .dtype ]), a .rtag (1 ).cast (dt )))]
438- pat += [(UPat ((* (GroupOp .ALU - GroupOp .Comparison ), Ops .BITCAST ), tuple (l2i_dt .keys ()), name = "x" ), lambda x :
439- None if x .tag is None else l2i (x .op , l2i_dt [x .dtype ], * flatten ((a .rtag (0 ).cast (dt := l2i_dt [x .src [- 1 ].dtype ]), a .rtag (1 ).cast (dt ))
440- if a .dtype in l2i_dt else (a ,) for a in x .src ))[x .tag ])]
441- pat += [(UPat (Ops .LOAD , tuple (l2i_dt .keys ()), src = (UPat .var ('idx' ),), name = 'x' ), lambda x ,idx :
442- None if x .tag is None else x .replace (dtype = l2i_dt [x .dtype ], src = (l2i_idx (idx , x .tag ),)))]
443- pat += [(UPat (Ops .CONST , tuple (l2i_dt .keys ()), name = 'x' ), lambda x :
444- None if x .tag is None else UOp .const (l2i_dt [x .dtype ], (x .arg >> 32 ) if x .tag == 1 else (x .arg & 0xFFFFFFFF )))]
445367 return PatternMatcher (pat )
0 commit comments