@@ -76,7 +76,7 @@ def _post_load_weights(self) -> None:
7676 if self .weight_scale .ndim > 1 :
7777 self .weight_scale = self .weight_scale .transpose (0 , 1 ).cuda (get_current_device_id ())
7878 self .weight = [
79- self .weight .transpose (0 , 1 ).cuda (),
79+ self .weight .transpose (0 , 1 ).cuda (get_current_device_id () ),
8080 self .weight_scale ,
8181 self .input_scale ,
8282 ]
@@ -151,7 +151,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
151151
152152 if self .act_scale_name is not None and self .act_scale_name in weights :
153153 input_scale = weights [self .act_scale_name ].to (torch .float )
154- self .input_scale = input_scale .cuda ()
154+ self .input_scale = input_scale .cuda (get_current_device_id () )
155155
156156 if weight is None and weight_scale is None and input_scale is None :
157157 return
@@ -213,7 +213,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
213213
214214 if self .static_activation and self .act_scale_name in weights :
215215 input_scale = weights [self .act_scale_name ].to (torch .float )
216- self .input_scale = input_scale .cuda ()
216+ self .input_scale = input_scale .cuda (get_current_device_id () )
217217
218218 if weight is None and weight_scale is None and input_scale is None :
219219 return
@@ -291,13 +291,13 @@ def _fuse(self) -> None:
291291 delattr (self , "weights" )
292292
293293 if self .weight_scale is None and (None not in self .weight_scales ):
294- self .weight_scale = torch .cat (self .weight_scales , dim = 0 ).cuda ()
294+ self .weight_scale = torch .cat (self .weight_scales , dim = 0 ).cuda (get_current_device_id () )
295295 self ._post_load_weights ()
296296 delattr (self , "weight_scales" )
297297
298298 if self .static_activation and self .input_scale is None and (None not in self .input_scales ):
299299 input_scales = torch .stack (self .input_scales , dim = 0 )
300- self .input_scale = torch .max (input_scales ).cuda ()
300+ self .input_scale = torch .max (input_scales ).cuda (get_current_device_id () )
301301 self ._post_load_weights ()
302302 delattr (self , "input_scales" )
303303
@@ -528,7 +528,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
528528
529529 if self .act_scale_name is not None and self .act_scale_name in weights :
530530 input_scale = weights [self .act_scale_name ].to (torch .float )
531- self .input_scale = input_scale .cuda ()
531+ self .input_scale = input_scale .cuda (get_current_device_id () )
532532
533533 if weight is None and weight_scale is None and input_scale is None :
534534 return
0 commit comments