@@ -532,6 +532,94 @@ class TestClass(PClass):
532532 assert not comparator (v , x )
533533
534534
535+ def test_torch ():
536+ try :
537+ import torch # type: ignore
538+ except ImportError :
539+ pytest .skip ()
540+
541+ a = torch .tensor ([1 , 2 , 3 ])
542+ b = torch .tensor ([1 , 2 , 3 ])
543+ c = torch .tensor ([1 , 2 , 4 ])
544+ assert comparator (a , b )
545+ assert not comparator (a , c )
546+
547+ d = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
548+ e = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
549+ f = torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 7 ]])
550+ assert comparator (d , e )
551+ assert not comparator (d , f )
552+
553+ # Test tensors with different data types
554+ g = torch .tensor ([1 , 2 , 3 ], dtype = torch .float32 )
555+ h = torch .tensor ([1 , 2 , 3 ], dtype = torch .float32 )
556+ i = torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 )
557+ assert comparator (g , h )
558+ assert not comparator (g , i )
559+
560+ # Test 3D tensors
561+ j = torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
562+ k = torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
563+ l = torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 9 ]]])
564+ assert comparator (j , k )
565+ assert not comparator (j , l )
566+
567+ # Test tensors with different shapes
568+ m = torch .tensor ([1 , 2 , 3 ])
569+ n = torch .tensor ([[1 , 2 , 3 ]])
570+ assert not comparator (m , n )
571+
572+ # Test empty tensors
573+ o = torch .tensor ([])
574+ p = torch .tensor ([])
575+ q = torch .tensor ([1 ])
576+ assert comparator (o , p )
577+ assert not comparator (o , q )
578+
579+ # Test tensors with NaN values
580+ r = torch .tensor ([1.0 , float ('nan' ), 3.0 ])
581+ s = torch .tensor ([1.0 , float ('nan' ), 3.0 ])
582+ t = torch .tensor ([1.0 , 2.0 , 3.0 ])
583+ assert comparator (r , s ) # NaN == NaN
584+ assert not comparator (r , t )
585+
586+ # Test tensors with infinity values
587+ u = torch .tensor ([1.0 , float ('inf' ), 3.0 ])
588+ v = torch .tensor ([1.0 , float ('inf' ), 3.0 ])
589+ w = torch .tensor ([1.0 , float ('-inf' ), 3.0 ])
590+ assert comparator (u , v )
591+ assert not comparator (u , w )
592+
593+ # Test tensors with different devices (if CUDA is available)
594+ if torch .cuda .is_available ():
595+ x = torch .tensor ([1 , 2 , 3 ]).cuda ()
596+ y = torch .tensor ([1 , 2 , 3 ]).cuda ()
597+ z = torch .tensor ([1 , 2 , 3 ])
598+ assert comparator (x , y )
599+ assert not comparator (x , z )
600+
601+ # Test tensors with requires_grad
602+ aa = torch .tensor ([1. , 2. , 3. ], requires_grad = True )
603+ bb = torch .tensor ([1. , 2. , 3. ], requires_grad = True )
604+ cc = torch .tensor ([1. , 2. , 3. ], requires_grad = False )
605+ assert comparator (aa , bb )
606+ assert not comparator (aa , cc )
607+
608+ # Test complex tensors
609+ dd = torch .tensor ([1 + 2j , 3 + 4j ])
610+ ee = torch .tensor ([1 + 2j , 3 + 4j ])
611+ ff = torch .tensor ([1 + 2j , 3 + 5j ])
612+ assert comparator (dd , ee )
613+ assert not comparator (dd , ff )
614+
615+ # Test boolean tensors
616+ gg = torch .tensor ([True , False , True ])
617+ hh = torch .tensor ([True , False , True ])
618+ ii = torch .tensor ([True , True , True ])
619+ assert comparator (gg , hh )
620+ assert not comparator (gg , ii )
621+
622+
535623def test_returns ():
536624 a = Success (5 )
537625 b = Success (5 )
0 commit comments