@@ -583,6 +583,43 @@ def test_torch():
583583 assert comparator (r , s ) # NaN == NaN
584584 assert not comparator (r , t )
585585
586+ # Test tensors with infinity values
587+ u = torch .tensor ([1.0 , float ('inf' ), 3.0 ])
588+ v = torch .tensor ([1.0 , float ('inf' ), 3.0 ])
589+ w = torch .tensor ([1.0 , float ('-inf' ), 3.0 ])
590+ assert comparator (u , v )
591+ assert not comparator (u , w )
592+
593+ # Test tensors with different devices (if CUDA is available)
594+ if torch .cuda .is_available ():
595+ x = torch .tensor ([1 , 2 , 3 ]).cuda ()
596+ y = torch .tensor ([1 , 2 , 3 ]).cuda ()
597+ z = torch .tensor ([1 , 2 , 3 ])
598+ assert comparator (x , y )
599+ assert not comparator (x , z )
600+
601+ # Test tensors with requires_grad
602+ aa = torch .tensor ([1. , 2. , 3. ], requires_grad = True )
603+ bb = torch .tensor ([1. , 2. , 3. ], requires_grad = True )
604+ cc = torch .tensor ([1. , 2. , 3. ], requires_grad = False )
605+ assert comparator (aa , bb )
606+ assert not comparator (aa , cc )
607+
608+ # Test complex tensors
609+ dd = torch .tensor ([1 + 2j , 3 + 4j ])
610+ ee = torch .tensor ([1 + 2j , 3 + 4j ])
611+ ff = torch .tensor ([1 + 2j , 3 + 5j ])
612+ assert comparator (dd , ee )
613+ assert not comparator (dd , ff )
614+
615+ # Test boolean tensors
616+ gg = torch .tensor ([True , False , True ])
617+ hh = torch .tensor ([True , False , True ])
618+ ii = torch .tensor ([True , True , True ])
619+ assert comparator (gg , hh )
620+ assert not comparator (gg , ii )
621+
622+
586623def test_returns ():
587624 a = Success (5 )
588625 b = Success (5 )
0 commit comments