@@ -710,6 +710,81 @@ def test_torch():
710710 assert not comparator (gg , ii )
711711
712712
713+ def test_jax ():
714+ try :
715+ import jax .numpy as jnp
716+ except ImportError :
717+ pytest .skip ()
718+
719+ # Test basic arrays
720+ a = jnp .array ([1 , 2 , 3 ])
721+ b = jnp .array ([1 , 2 , 3 ])
722+ c = jnp .array ([1 , 2 , 4 ])
723+ assert comparator (a , b )
724+ assert not comparator (a , c )
725+
726+ # Test 2D arrays
727+ d = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
728+ e = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
729+ f = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 7 ]])
730+ assert comparator (d , e )
731+ assert not comparator (d , f )
732+
733+ # Test arrays with different data types
734+ g = jnp .array ([1 , 2 , 3 ], dtype = jnp .float32 )
735+ h = jnp .array ([1 , 2 , 3 ], dtype = jnp .float32 )
736+ i = jnp .array ([1 , 2 , 3 ], dtype = jnp .int32 )
737+ assert comparator (g , h )
738+ assert not comparator (g , i )
739+
740+ # Test 3D arrays
741+ j = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
742+ k = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
743+ l = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 9 ]]])
744+ assert comparator (j , k )
745+ assert not comparator (j , l )
746+
747+ # Test arrays with different shapes
748+ m = jnp .array ([1 , 2 , 3 ])
749+ n = jnp .array ([[1 , 2 , 3 ]])
750+ assert not comparator (m , n )
751+
752+ # Test empty arrays
753+ o = jnp .array ([])
754+ p = jnp .array ([])
755+ q = jnp .array ([1 ])
756+ assert comparator (o , p )
757+ assert not comparator (o , q )
758+
759+ # Test arrays with NaN values
760+ r = jnp .array ([1.0 , jnp .nan , 3.0 ])
761+ s = jnp .array ([1.0 , jnp .nan , 3.0 ])
762+ t = jnp .array ([1.0 , 2.0 , 3.0 ])
763+ assert comparator (r , s ) # NaN == NaN
764+ assert not comparator (r , t )
765+
766+ # Test arrays with infinity values
767+ u = jnp .array ([1.0 , jnp .inf , 3.0 ])
768+ v = jnp .array ([1.0 , jnp .inf , 3.0 ])
769+ w = jnp .array ([1.0 , - jnp .inf , 3.0 ])
770+ assert comparator (u , v )
771+ assert not comparator (u , w )
772+
773+ # Test complex arrays
774+ x = jnp .array ([1 + 2j , 3 + 4j ])
775+ y = jnp .array ([1 + 2j , 3 + 4j ])
776+ z = jnp .array ([1 + 2j , 3 + 5j ])
777+ assert comparator (x , y )
778+ assert not comparator (x , z )
779+
780+ # Test boolean arrays
781+ aa = jnp .array ([True , False , True ])
782+ bb = jnp .array ([True , False , True ])
783+ cc = jnp .array ([True , True , True ])
784+ assert comparator (aa , bb )
785+ assert not comparator (aa , cc )
786+
787+
713788def test_returns ():
714789 a = Success (5 )
715790 b = Success (5 )
0 commit comments