@@ -694,32 +694,30 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
694694 def to (self , * args , ** kwargs ):
695695 device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
696696
697- if device is not None and device . type in ( "cuda" , "xpu" , "cpu" ) :
697+ if device is not None :
698698 if device .type == "cuda" and self .data .device .type == "cpu" :
699699 return self .cuda (device )
700700 elif device .type == "cpu" :
701701 if self .data .dtype == torch .int8 :
702702 self .CB = self .data
703- return self
704703 else :
705704 return self .cpu ()
706705 elif device .type == "xpu" :
707706 if self .data .dtype == torch .int8 :
708- self .data = self .data .contiguous (). xpu ( device )
707+ self .data = self .data .contiguous ()
709708 self .CB = self .data
710- return self
711- else :
709+ if self .data .device .type == "cpu" :
712710 return self .xpu (device )
713- else :
714- new_param = Int8Params (
715- super ().to (device = device , dtype = dtype , non_blocking = non_blocking ),
716- requires_grad = self .requires_grad ,
717- has_fp16_weights = self .has_fp16_weights ,
718- )
719- new_param .CB = self .CB
720- new_param .SCB = self .SCB
721711
722- return new_param
712+ new_param = Int8Params (
713+ super ().to (device = device , dtype = dtype , non_blocking = non_blocking ),
714+ requires_grad = self .requires_grad ,
715+ has_fp16_weights = self .has_fp16_weights ,
716+ )
717+ new_param .CB = self .CB
718+ new_param .SCB = self .SCB
719+
720+ return new_param
723721
724722
725723def maybe_rearrange_weight (state_dict , prefix , local_metadata , strict , missing_keys , unexpected_keys , error_msgs ):
0 commit comments