@@ -679,19 +679,27 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
679679 def to (self , * args , ** kwargs ):
680680 device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
681681
682- if device is not None and device .type != "meta" and self .data .device .type == "cpu" :
683- if device .type != "cpu" or self .data .dtype != torch .int8 :
684- return self ._quantize (device )
685- elif self .data .dtype == torch .int8 and device .type == "cpu" :
686- self .CB = self .data
682+ is_quantized = self .data .dtype == torch .int8
687683
684+ if not is_quantized and device is not None and device .type != "meta" and self .data .device .type == "cpu" :
685+ # We're moving from a CPU device to a non-meta device.
686+ # In this circumstance, we want to quantize if we haven't already.
687+ return self ._quantize (device )
688+
689+ # Create a new parameter on the target device.
688690 new_param = Int8Params (
689691 super ().to (device = device , dtype = dtype , non_blocking = non_blocking ),
690692 requires_grad = self .requires_grad ,
691693 has_fp16_weights = self .has_fp16_weights ,
692694 )
693- new_param .CB = self .CB
694- new_param .SCB = self .SCB
695+
696+ # If we had already quantized, move the statistics appropriately.
697+ if is_quantized and device is not None :
698+ if self .CB is not None :
699+ new_param .CB = new_param .data
700+
701+ if self .SCB is not None :
702+ new_param .SCB = self .SCB .to (device )
695703
696704 return new_param
697705
@@ -1037,6 +1045,21 @@ def init_8bit_state(self):
10371045 self .weight .CB = None
10381046 self .weight .SCB = None
10391047
1048+ def to (self , * args , ** kwargs ):
1049+ # Call the parent to() method to handle standard parameter/buffer movement
1050+ result = super ().to (* args , ** kwargs )
1051+
1052+ device , dtype , non_blocking , convert_to_format = torch ._C ._nn ._parse_to (* args , ** kwargs )
1053+
1054+ # Handle state tensors if needed.
1055+ if device is not None :
1056+ if result .state .CB is not None :
1057+ result .state .CB = result .state .CB .to (device )
1058+ if result .state .SCB is not None :
1059+ result .state .SCB = result .state .SCB .to (device )
1060+
1061+ return result
1062+
10401063 def forward (self , x : torch .Tensor ):
10411064 self .state .is_training = self .training
10421065 if self .weight .CB is not None :
0 commit comments