@@ -710,6 +710,81 @@ def test_torch():
710
710
assert not comparator (gg , ii )
711
711
712
712
713
+ def test_jax ():
714
+ try :
715
+ import jax .numpy as jnp
716
+ except ImportError :
717
+ pytest .skip ()
718
+
719
+ # Test basic arrays
720
+ a = jnp .array ([1 , 2 , 3 ])
721
+ b = jnp .array ([1 , 2 , 3 ])
722
+ c = jnp .array ([1 , 2 , 4 ])
723
+ assert comparator (a , b )
724
+ assert not comparator (a , c )
725
+
726
+ # Test 2D arrays
727
+ d = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
728
+ e = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
729
+ f = jnp .array ([[1 , 2 , 3 ], [4 , 5 , 7 ]])
730
+ assert comparator (d , e )
731
+ assert not comparator (d , f )
732
+
733
+ # Test arrays with different data types
734
+ g = jnp .array ([1 , 2 , 3 ], dtype = jnp .float32 )
735
+ h = jnp .array ([1 , 2 , 3 ], dtype = jnp .float32 )
736
+ i = jnp .array ([1 , 2 , 3 ], dtype = jnp .int32 )
737
+ assert comparator (g , h )
738
+ assert not comparator (g , i )
739
+
740
+ # Test 3D arrays
741
+ j = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
742
+ k = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
743
+ l = jnp .array ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 9 ]]])
744
+ assert comparator (j , k )
745
+ assert not comparator (j , l )
746
+
747
+ # Test arrays with different shapes
748
+ m = jnp .array ([1 , 2 , 3 ])
749
+ n = jnp .array ([[1 , 2 , 3 ]])
750
+ assert not comparator (m , n )
751
+
752
+ # Test empty arrays
753
+ o = jnp .array ([])
754
+ p = jnp .array ([])
755
+ q = jnp .array ([1 ])
756
+ assert comparator (o , p )
757
+ assert not comparator (o , q )
758
+
759
+ # Test arrays with NaN values
760
+ r = jnp .array ([1.0 , jnp .nan , 3.0 ])
761
+ s = jnp .array ([1.0 , jnp .nan , 3.0 ])
762
+ t = jnp .array ([1.0 , 2.0 , 3.0 ])
763
+ assert comparator (r , s ) # NaN == NaN
764
+ assert not comparator (r , t )
765
+
766
+ # Test arrays with infinity values
767
+ u = jnp .array ([1.0 , jnp .inf , 3.0 ])
768
+ v = jnp .array ([1.0 , jnp .inf , 3.0 ])
769
+ w = jnp .array ([1.0 , - jnp .inf , 3.0 ])
770
+ assert comparator (u , v )
771
+ assert not comparator (u , w )
772
+
773
+ # Test complex arrays
774
+ x = jnp .array ([1 + 2j , 3 + 4j ])
775
+ y = jnp .array ([1 + 2j , 3 + 4j ])
776
+ z = jnp .array ([1 + 2j , 3 + 5j ])
777
+ assert comparator (x , y )
778
+ assert not comparator (x , z )
779
+
780
+ # Test boolean arrays
781
+ aa = jnp .array ([True , False , True ])
782
+ bb = jnp .array ([True , False , True ])
783
+ cc = jnp .array ([True , True , True ])
784
+ assert comparator (aa , bb )
785
+ assert not comparator (aa , cc )
786
+
787
+
713
788
def test_returns ():
714
789
a = Success (5 )
715
790
b = Success (5 )
0 commit comments