@@ -368,8 +368,14 @@ def __add__(self, other: FloatLike) -> Float: ...
368368 def __sub__ (self , other : FloatLike ) -> Float : ...
369369
370370 def __pow__ (self , other : FloatLike ) -> Float : ...
371+ def __round__ (self , ndigits : OptionalIntLike = None ) -> Float : ...
371372
372373 def __eq__ (self , other : FloatLike ) -> Boolean : ... # type: ignore[override]
374+ def __ne__ (self , other : FloatLike ) -> Boolean : ... # type: ignore[override]
375+ def __lt__ (self , other : FloatLike ) -> Boolean : ...
376+ def __le__ (self , other : FloatLike ) -> Boolean : ...
377+ def __gt__ (self , other : FloatLike ) -> Boolean : ...
378+ def __ge__ (self , other : FloatLike ) -> Boolean : ...
373379
374380
375381converter (float , Float , lambda x : Float (x ))
@@ -380,9 +386,10 @@ def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
380386
381387
382388@array_api_ruleset .register
383- def _float (fl : Float , f : f64 , f2 : f64 , i : i64 , r : BigRat , r1 : BigRat ):
389+ def _float (fl : Float , f : f64 , f2 : f64 , i : i64 , r : BigRat , r1 : BigRat , i_ : Int ):
384390 return [
385391 rule (eq (fl ).to (Float (f ))).then (set_ (fl .to_f64 ).to (f )),
392+ rewrite (Float .from_int (Int (i ))).to (Float (f64 .from_i64 (i ))),
386393 rewrite (Float (f ).abs ()).to (Float (f ), f >= 0.0 ),
387394 rewrite (Float (f ).abs ()).to (Float (- f ), f < 0.0 ),
388395 # Convert from float to rationl, if its a whole number i.e. can be converted to int
@@ -397,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
397404 rewrite (Float .rational (r ) - Float .rational (r1 )).to (Float .rational (r - r1 )),
398405 rewrite (Float .rational (r ) * Float .rational (r1 )).to (Float .rational (r * r1 )),
399406 rewrite (Float (f ) ** Float (f2 )).to (Float (f ** f2 )),
400- # ==
407+ # comparisons
401408 rewrite (Float (f ) == Float (f )).to (TRUE ),
402409 rewrite (Float (f ) == Float (f2 )).to (FALSE , ne (f ).to (f2 )),
410+ rewrite (Float (f ) != Float (f2 )).to (TRUE , f != f2 ),
411+ rewrite (Float (f ) != Float (f )).to (FALSE ),
412+ rewrite (Float (f ) >= Float (f2 )).to (TRUE , f >= f2 ),
413+ rewrite (Float (f ) >= Float (f2 )).to (FALSE , f < f2 ),
414+ rewrite (Float (f ) <= Float (f2 )).to (TRUE , f <= f2 ),
415+ rewrite (Float (f ) <= Float (f2 )).to (FALSE , f > f2 ),
416+ rewrite (Float (f ) > Float (f2 )).to (TRUE , f > f2 ),
417+ rewrite (Float (f ) > Float (f2 )).to (FALSE , f <= f2 ),
418+ rewrite (Float (f ) < Float (f2 )).to (TRUE , f < f2 ),
403419 rewrite (Float .rational (r ) == Float .rational (r )).to (TRUE ),
404420 rewrite (Float .rational (r ) == Float .rational (r1 )).to (FALSE , ne (r ).to (r1 )),
421+ # round
422+ rewrite (Float .rational (r ).__round__ ()).to (Float .rational (r .round ())),
405423 ]
406424
407425
0 commit comments