@@ -189,7 +189,7 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
189189 yield rewrite (Int (i ) + Int (j )).to (Int (i + j ))
190190 yield rewrite (Int (i ) - Int (j )).to (Int (i - j ))
191191 yield rewrite (Int (i ) * Int (j )).to (Int (i * j ))
192- yield rewrite (Int (i ) / Int (j )).to (Int (i / j ))
192+ yield rewrite (Int (i ) // Int (j )).to (Int (i / j ))
193193 yield rewrite (Int (i ) % Int (j )).to (Int (i % j ))
194194 yield rewrite (Int (i ) & Int (j )).to (Int (i & j ))
195195 yield rewrite (Int (i ) | Int (j )).to (Int (i | j ))
@@ -219,15 +219,17 @@ def abs(self) -> Float: ...
219219 def rational (cls , r : Rational ) -> Float : ...
220220
221221 @classmethod
222- def from_int (cls , i : Int ) -> Float : ...
222+ def from_int (cls , i : IntLike ) -> Float : ...
223223
224- def __truediv__ (self , other : Float ) -> Float : ...
224+ def __truediv__ (self , other : FloatLike ) -> Float : ...
225225
226- def __mul__ (self , other : Float ) -> Float : ...
226+ def __mul__ (self , other : FloatLike ) -> Float : ...
227227
228- def __add__ (self , other : Float ) -> Float : ...
228+ def __add__ (self , other : FloatLike ) -> Float : ...
229229
230- def __sub__ (self , other : Float ) -> Float : ...
230+ def __sub__ (self , other : FloatLike ) -> Float : ...
231+
232+ def __pow__ (self , other : FloatLike ) -> Float : ...
231233
232234
233235converter (float , Float , lambda x : Float (x ))
@@ -252,6 +254,7 @@ def _float(f: f64, f2: f64, i: i64, r: Rational, r1: Rational):
252254 rewrite (Float .rational (r ) + Float .rational (r1 )).to (Float .rational (r + r1 )),
253255 rewrite (Float .rational (r ) - Float .rational (r1 )).to (Float .rational (r - r1 )),
254256 rewrite (Float .rational (r ) * Float .rational (r1 )).to (Float .rational (r * r1 )),
257+ rewrite (Float (f ) ** Float (f2 )).to (Float (f ** f2 )),
255258 ]
256259
257260
@@ -271,6 +274,7 @@ def var(cls, name: StringLike) -> TupleInt: ...
271274
272275 EMPTY : ClassVar [TupleInt ]
273276
277+ @method (cost = 100 )
274278 def __init__ (self , length : IntLike , idx_fn : Callable [[Int ], Int ]) -> None : ...
275279
276280 @classmethod
@@ -325,13 +329,57 @@ def if_(cls, b: Boolean, i: TupleInt, j: TupleInt) -> TupleInt: ...
325329 def to_py (self ) -> tuple [int , ...]:
326330 return tuple (int (i ) for i in self )
327331
332+ @method (subsume = True )
333+ def drop (self , n : Int ) -> TupleInt :
334+ return TupleInt (self .length () - n , lambda i : self [i + n ])
335+
336+ @method (subsume = True )
337+ def product (self ) -> Int :
338+ return self .fold (Int (1 ), lambda acc , i : acc * i )
339+
340+ def map_tuple_int (self , f : Callable [[Int ], TupleInt ]) -> TupleTupleInt : ...
341+
342+ def append (self , i : Int ) -> TupleInt : ...
343+
328344
329345# TODO: Upcast args for Vec[Int] constructor
330346converter (tuple , TupleInt , lambda x : TupleInt .from_vec (Vec (* (convert (i , Int ) for i in x ))))
331347
332348TupleIntLike : TypeAlias = TupleInt | tuple [IntLike , ...]
333349
334350
351+ @array_api_ruleset .register
352+ def _tuple_int_create_from_vec (
353+ x : NDArray , idx_fn : Callable [[Int ], Int ], i : i64 , xs : Vec [Int ], ti : TupleInt , ti2 : TupleInt , v : Int
354+ ):
355+ """
356+ Turn a tuple into constructor with a known length into a from_vec constructor
357+ """
358+ # # create from_vec from zero length tuple
359+ # yield rule(eq(ti).to(TupleInt(0, idx_fn))).then(union(ti).with_(TupleInt.from_vec(Vec[Int]())))
360+
361+ # yield rewrite(x.index(TupleInt(0, idx_fn))).to(x.index(TupleInt.from_vec(Vec[Int]())))
362+ # yield rewrite(x.index(TupleInt(1, idx_fn))).to(x.index(TupleInt.from_vec(Vec(idx_fn(Int(0))))))
363+ # yield rewrite(x.index(TupleInt(2, idx_fn))).to(x.index(TupleInt.from_vec(Vec(idx_fn(Int(0)), idx_fn(Int(1))))))
364+ # yield rewrite(x.index(TupleInt(3, idx_fn))).to(
365+ # x.index(TupleInt.from_vec(Vec(idx_fn(Int(0)), idx_fn(Int(1)), idx_fn(Int(2)))))
366+ # )
367+ yield rewrite (x .index (TupleInt (4 , idx_fn ))).to (
368+ x .index (
369+ TupleInt .from_vec (Vec (idx_fn (Int (0 )), idx_fn (Int (1 )), idx_fn (Int (2 )), idx_fn (Int (3 ))))
370+ # TupleInt.EMPTY.append(idx_fn(Int(0))).append(idx_fn(Int(1))).append(idx_fn(Int(2))).append(idx_fn(Int(3)))
371+ )
372+ )
373+
374+
375+ # # Also create it when appending onto a tuple that already has a vec
376+ # # yield rule(eq(ti).to(ti2.append(v)), eq(ti2).to(TupleInt.from_vec(xs))).then(
377+ # # union(ti).with_(TupleInt.from_vec(xs.append(Vec(v))))
378+ # # )
379+ # # Split up known length tuple vecs into append calls so they will be transformed into from_vec
380+ # yield rewrite(TupleInt(i, idx_fn)).to(TupleInt(i - 1, idx_fn).append(idx_fn(Int(i - 1))), i > 0)
381+
382+
335383@array_api_ruleset .register
336384def _tuple_int (
337385 i : Int ,
@@ -340,7 +388,7 @@ def _tuple_int(
340388 f : Callable [[Int , Int ], Int ],
341389 bool_f : Callable [[Boolean , Int ], Boolean ],
342390 idx_fn : Callable [[Int ], Int ],
343- map_fn : Callable [[Int ], Int ],
391+ map_tuple_int_fn : Callable [[Int ], TupleInt ],
344392 filter_f : Callable [[Int ], Boolean ],
345393 vs : Vec [Int ],
346394 b : Boolean ,
@@ -351,7 +399,7 @@ def _tuple_int(
351399 rewrite (TupleInt (i , idx_fn ).length ()).to (i ),
352400 rewrite (TupleInt (i , idx_fn )[i2 ]).to (idx_fn (i2 )),
353401 # index_vec_int
354- rewrite ( index_vec_int (vs , Int (k ))). to ( vs [ k ], vs .length () > k ),
402+ rule ( eq ( i ). to ( index_vec_int (vs , Int (k ))), k < vs .length (), k >= 0 ). then ( union ( i ). with_ ( vs [ k ]) ),
355403 # fold
356404 rewrite (TupleInt (0 , idx_fn ).fold (i , f )).to (i ),
357405 rewrite (TupleInt (Int (k ), idx_fn ).fold (i , f )).to (
@@ -379,6 +427,10 @@ def _tuple_int(
379427 # if_
380428 rewrite (TupleInt .if_ (TRUE , ti , ti2 )).to (ti ),
381429 rewrite (TupleInt .if_ (FALSE , ti , ti2 )).to (ti2 ),
430+ # map_tuple_int
431+ rewrite (TupleInt (i , idx_fn ).map_tuple_int (map_tuple_int_fn )).to (
432+ TupleTupleInt (i , lambda i : map_tuple_int_fn (idx_fn (i )))
433+ ),
382434 ]
383435
384436
@@ -418,6 +470,55 @@ def __len__(self) -> int:
418470 def __iter__ (self ) -> Iterator [TupleInt ]:
419471 return iter (self [i ] for i in range (len (self )))
420472
473+ def drop (self , n : Int ) -> TupleTupleInt :
474+ return TupleTupleInt (self .length () - n , lambda i : self [i + n ])
475+
476+ def map_int (self , f : Callable [[TupleInt ], Int ]) -> TupleInt : ...
477+
478+ def reduce_value (self , f : Callable [[Value , TupleInt ], Value ], init : ValueLike ) -> Value : ...
479+
480+ def product (self ) -> TupleTupleInt :
481+ """
482+ Cartesian product of inputs
483+
484+ https://docs.python.org/3/library/itertools.html#itertools.product
485+
486+ https://github.com/saulshanabrook/saulshanabrook/discussions/39
487+ """
488+ return TupleTupleInt (
489+ self .map_int (lambda x : x .length ()).product (),
490+ lambda i : TupleInt (
491+ self .length (),
492+ lambda j : self [j ][i // self .drop (j ).map_int (lambda x : x .length ()).product () % self [j ].length ()],
493+ ),
494+ )
495+
496+
497+ @array_api_ruleset .register
498+ def _tuple_tuple_int (
499+ length : Int ,
500+ fn : Callable [[TupleInt ], Int ],
501+ idx_fn : Callable [[Int ], TupleInt ],
502+ f : Callable [[Value , TupleInt ], Value ],
503+ i : Value ,
504+ k : i64 ,
505+ idx : Int ,
506+ ):
507+ yield rewrite (TupleTupleInt (length , idx_fn ).length ()).to (length )
508+
509+ yield rewrite (TupleTupleInt (length , idx_fn )[idx ]).to (idx_fn (idx ))
510+
511+ yield rewrite (TupleTupleInt (length , idx_fn ).map_int (fn ), subsume = True ).to (TupleInt (length , lambda i : fn (idx_fn (i ))))
512+
513+ yield rewrite (TupleTupleInt (0 , idx_fn ).reduce_value (f , i )).to (i )
514+ yield rewrite (TupleTupleInt (Int (k ), idx_fn ).reduce_value (f , i ), subsume = True ).to (
515+ f (
516+ TupleTupleInt (k - 1 , lambda i : idx_fn (i + 1 )).reduce_value (f , i ),
517+ idx_fn (Int (0 )),
518+ ),
519+ ne (k ).to (i64 (0 )),
520+ )
521+
421522
422523@function
423524def bottom_indexing (i : Int ) -> Int : ...
@@ -627,19 +728,23 @@ class Device(Expr): ...
627728
628729class Value (Expr ):
629730 @classmethod
630- def int (cls , i : Int ) -> Value : ...
731+ def int (cls , i : IntLike ) -> Value : ...
631732
632733 @classmethod
633- def float (cls , f : Float ) -> Value : ...
734+ def float (cls , f : FloatLike ) -> Value : ...
634735
635736 @classmethod
636- def bool (cls , b : Boolean ) -> Value : ...
737+ def bool (cls , b : BooleanLike ) -> Value : ...
637738
638739 def isfinite (self ) -> Boolean : ...
639740
640- def __lt__ (self , other : Value ) -> Value : ...
741+ def __lt__ (self , other : ValueLike ) -> Value : ...
742+
743+ def __truediv__ (self , other : ValueLike ) -> Value : ...
641744
642- def __truediv__ (self , other : Value ) -> Value : ...
745+ def __mul__ (self , other : ValueLike ) -> Value : ...
746+
747+ def __add__ (self , other : ValueLike ) -> Value : ...
643748
644749 def astype (self , dtype : DType ) -> Value : ...
645750
@@ -665,17 +770,21 @@ def to_truthy_value(self) -> Value:
665770 https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.any.html
666771 """
667772
773+ def conj (self ) -> Value : ...
774+ def real (self ) -> Value : ...
775+ def sqrt (self ) -> Value : ...
776+
777+
778+ ValueLike : TypeAlias = Value | IntLike | FloatLike | BooleanLike
668779
669780converter (Int , Value , Value .int )
670781converter (Float , Value , Value .float )
671782converter (Boolean , Value , Value .bool )
672783converter (Value , Int , lambda x : x .to_int , 10 )
673784
674- ValueLike : TypeAlias = Value | IntLike | FloatLike | BooleanLike
675-
676785
677786@array_api_ruleset .register
678- def _value (i : Int , f : Float , b : Boolean ):
787+ def _value (i : Int , f : Float , b : Boolean , v : Value ):
679788 # Default dtypes
680789 # https://data-apis.org/array-api/latest/API_specification/data_types.html?highlight=dtype#default-data-types
681790 yield rewrite (Value .int (i ).dtype ).to (DType .int64 )
@@ -688,6 +797,15 @@ def _value(i: Int, f: Float, b: Boolean):
688797 yield rewrite (Value .bool (b ).to_truthy_value ).to (Value .bool (b ))
689798 # TODO: Add more rules for to_bool_value
690799
800+ yield rewrite (Value .float (f ).conj ()).to (Value .float (f ))
801+ yield rewrite (Value .float (f ).real ()).to (Value .float (f ))
802+ yield rewrite (Value .int (i ).real ()).to (Value .int (i ))
803+ yield rewrite (Value .int (i ).conj ()).to (Value .int (i ))
804+
805+ yield rewrite (Value .float (f ).sqrt ()).to (Value .float (f ** (0.5 )))
806+
807+ yield rewrite (Value .float (Float .rational (Rational (0 , 1 ))) + v ).to (v )
808+
691809
692810class TupleValue (Expr ):
693811 EMPTY : ClassVar [TupleValue ]
0 commit comments