@@ -532,6 +532,94 @@ class TestClass(PClass):
532532    assert  not  comparator (v , x )
533533
534534
535+ def  test_torch ():
536+     try :
537+         import  torch   # type: ignore 
538+     except  ImportError :
539+         pytest .skip ()
540+ 
541+     a  =  torch .tensor ([1 , 2 , 3 ])
542+     b  =  torch .tensor ([1 , 2 , 3 ])
543+     c  =  torch .tensor ([1 , 2 , 4 ])
544+     assert  comparator (a , b )
545+     assert  not  comparator (a , c )
546+ 
547+     d  =  torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
548+     e  =  torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
549+     f  =  torch .tensor ([[1 , 2 , 3 ], [4 , 5 , 7 ]])
550+     assert  comparator (d , e )
551+     assert  not  comparator (d , f )
552+ 
553+     # Test tensors with different data types 
554+     g  =  torch .tensor ([1 , 2 , 3 ], dtype = torch .float32 )
555+     h  =  torch .tensor ([1 , 2 , 3 ], dtype = torch .float32 )
556+     i  =  torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 )
557+     assert  comparator (g , h )
558+     assert  not  comparator (g , i )
559+ 
560+     # Test 3D tensors 
561+     j  =  torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
562+     k  =  torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 8 ]]])
563+     l  =  torch .tensor ([[[1 , 2 ], [3 , 4 ]], [[5 , 6 ], [7 , 9 ]]])
564+     assert  comparator (j , k )
565+     assert  not  comparator (j , l )
566+ 
567+     # Test tensors with different shapes 
568+     m  =  torch .tensor ([1 , 2 , 3 ])
569+     n  =  torch .tensor ([[1 , 2 , 3 ]])
570+     assert  not  comparator (m , n )
571+ 
572+     # Test empty tensors 
573+     o  =  torch .tensor ([])
574+     p  =  torch .tensor ([])
575+     q  =  torch .tensor ([1 ])
576+     assert  comparator (o , p )
577+     assert  not  comparator (o , q )
578+ 
579+     # Test tensors with NaN values 
580+     r  =  torch .tensor ([1.0 , float ('nan' ), 3.0 ])
581+     s  =  torch .tensor ([1.0 , float ('nan' ), 3.0 ])
582+     t  =  torch .tensor ([1.0 , 2.0 , 3.0 ])
583+     assert  comparator (r , s )  # NaN == NaN 
584+     assert  not  comparator (r , t )
585+ 
586+     # Test tensors with infinity values 
587+     u  =  torch .tensor ([1.0 , float ('inf' ), 3.0 ])
588+     v  =  torch .tensor ([1.0 , float ('inf' ), 3.0 ])
589+     w  =  torch .tensor ([1.0 , float ('-inf' ), 3.0 ])
590+     assert  comparator (u , v )
591+     assert  not  comparator (u , w )
592+ 
593+     # Test tensors with different devices (if CUDA is available) 
594+     if  torch .cuda .is_available ():
595+         x  =  torch .tensor ([1 , 2 , 3 ]).cuda ()
596+         y  =  torch .tensor ([1 , 2 , 3 ]).cuda ()
597+         z  =  torch .tensor ([1 , 2 , 3 ])
598+         assert  comparator (x , y )
599+         assert  not  comparator (x , z )
600+ 
601+     # Test tensors with requires_grad 
602+     aa  =  torch .tensor ([1. , 2. , 3. ], requires_grad = True )
603+     bb  =  torch .tensor ([1. , 2. , 3. ], requires_grad = True )
604+     cc  =  torch .tensor ([1. , 2. , 3. ], requires_grad = False )
605+     assert  comparator (aa , bb )
606+     assert  not  comparator (aa , cc )
607+ 
608+     # Test complex tensors 
609+     dd  =  torch .tensor ([1 + 2j , 3 + 4j ])
610+     ee  =  torch .tensor ([1 + 2j , 3 + 4j ])
611+     ff  =  torch .tensor ([1 + 2j , 3 + 5j ])
612+     assert  comparator (dd , ee )
613+     assert  not  comparator (dd , ff )
614+ 
615+     # Test boolean tensors 
616+     gg  =  torch .tensor ([True , False , True ])
617+     hh  =  torch .tensor ([True , False , True ])
618+     ii  =  torch .tensor ([True , True , True ])
619+     assert  comparator (gg , hh )
620+     assert  not  comparator (gg , ii )
621+ 
622+ 
535623def  test_returns ():
536624    a  =  Success (5 )
537625    b  =  Success (5 )
0 commit comments