@@ -310,28 +310,28 @@ def _quantize(self, device):
310310 def cpu (self ):
311311 return self .to (device = "cpu" )
312312
313- def cuda (self , device : Optional [Union [ int , device , str ] ] = None , non_blocking : bool = False ):
313+ def cuda (self , device : Optional [int | device | str ] = None , non_blocking : bool = False ):
314314 return self .to (device = "cuda" if device is None else device , non_blocking = non_blocking )
315315
316- def xpu (self , device : Optional [Union [ int , device , str ] ] = None , non_blocking : bool = False ):
316+ def xpu (self , device : Optional [int | device | str ] = None , non_blocking : bool = False ):
317317 return self .to (device = "xpu" if device is None else device , non_blocking = non_blocking )
318318
319319 @overload
320320 def to (
321321 self : T ,
322- device : Optional [Union [ int , device ] ] = ...,
323- dtype : Optional [Union [ dtype , str ] ] = ...,
322+ device : Optional [int | device ] = ...,
323+ dtype : Optional [dtype | str ] = ...,
324324 non_blocking : bool = ...,
325325 ) -> T : ...
326326
327327 @overload
328- def to (self : T , dtype : Union [ dtype , str ] , non_blocking : bool = ...) -> T : ...
328+ def to (self : T , dtype : dtype | str , non_blocking : bool = ...) -> T : ...
329329
330330 @overload
331331 def to (self : T , tensor : Tensor , non_blocking : bool = ...) -> T : ...
332332
333333 def to (self , * args , ** kwargs ):
334- device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
334+ device , dtype , non_blocking , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
335335
336336 if device is not None and device .type != "meta" and not self .bnb_quantized :
337337 return self ._quantize (device )
@@ -644,10 +644,10 @@ def _quantize(self, device):
644644 def cpu (self ):
645645 return self .to (device = "cpu" )
646646
647- def cuda (self , device : Optional [Union [ int , device , str ] ] = None , non_blocking : bool = False ):
647+ def cuda (self , device : Optional [int | device | str ] = None , non_blocking : bool = False ):
648648 return self .to (device = "cuda" if device is None else device , non_blocking = non_blocking )
649649
650- def xpu (self , device : Optional [Union [ int , device , str ] ] = None , non_blocking : bool = False ):
650+ def xpu (self , device : Optional [int | device | str ] = None , non_blocking : bool = False ):
651651 return self .to (device = "xpu" if device is None else device , non_blocking = non_blocking )
652652
653653 def __deepcopy__ (self , memo ):
@@ -665,19 +665,19 @@ def __deepcopy__(self, memo):
665665 @overload
666666 def to (
667667 self : T ,
668- device : Optional [Union [ int , device ] ] = ...,
669- dtype : Optional [Union [ dtype , str ] ] = ...,
668+ device : Optional [int | device ] = ...,
669+ dtype : Optional [dtype | str ] = ...,
670670 non_blocking : bool = ...,
671671 ) -> T : ...
672672
673673 @overload
674- def to (self : T , dtype : Union [ dtype , str ] , non_blocking : bool = ...) -> T : ...
674+ def to (self : T , dtype : dtype | str , non_blocking : bool = ...) -> T : ...
675675
676676 @overload
677677 def to (self : T , tensor : Tensor , non_blocking : bool = ...) -> T : ...
678678
679679 def to (self , * args , ** kwargs ):
680- device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
680+ device , dtype , non_blocking , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
681681
682682 is_quantized = self .data .dtype == torch .int8
683683
@@ -1048,7 +1048,7 @@ def to(self, *args, **kwargs):
10481048 # Call the parent to() method to handle standard parameter/buffer movement
10491049 result = super ().to (* args , ** kwargs )
10501050
1051- device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
1051+ device , _ , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
10521052
10531053 # Handle state tensors if needed.
10541054 if device is not None :
0 commit comments