Skip to content

Commit 64ccb4f

Browse files
committed
Pass the current device id to cuda() calls (use get_current_device_id())
1 parent a26d48d commit 64ccb4f

File tree

1 file changed

+6
-6
lines changed
  • lightllm/common/basemodel/layer_weights/meta_weights

1 file changed

+6
-6
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ def _post_load_weights(self) -> None:
7676
if self.weight_scale.ndim > 1:
7777
self.weight_scale = self.weight_scale.transpose(0, 1).cuda(get_current_device_id())
7878
self.weight = [
79-
self.weight.transpose(0, 1).cuda(),
79+
self.weight.transpose(0, 1).cuda(get_current_device_id()),
8080
self.weight_scale,
8181
self.input_scale,
8282
]
@@ -151,7 +151,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
151151

152152
if self.act_scale_name is not None and self.act_scale_name in weights:
153153
input_scale = weights[self.act_scale_name].to(torch.float)
154-
self.input_scale = input_scale.cuda()
154+
self.input_scale = input_scale.cuda(get_current_device_id())
155155

156156
if weight is None and weight_scale is None and input_scale is None:
157157
return
@@ -213,7 +213,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
213213

214214
if self.static_activation and self.act_scale_name in weights:
215215
input_scale = weights[self.act_scale_name].to(torch.float)
216-
self.input_scale = input_scale.cuda()
216+
self.input_scale = input_scale.cuda(get_current_device_id())
217217

218218
if weight is None and weight_scale is None and input_scale is None:
219219
return
@@ -291,13 +291,13 @@ def _fuse(self) -> None:
291291
delattr(self, "weights")
292292

293293
if self.weight_scale is None and (None not in self.weight_scales):
294-
self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda()
294+
self.weight_scale = torch.cat(self.weight_scales, dim=0).cuda(get_current_device_id())
295295
self._post_load_weights()
296296
delattr(self, "weight_scales")
297297

298298
if self.static_activation and self.input_scale is None and (None not in self.input_scales):
299299
input_scales = torch.stack(self.input_scales, dim=0)
300-
self.input_scale = torch.max(input_scales).cuda()
300+
self.input_scale = torch.max(input_scales).cuda(get_current_device_id())
301301
self._post_load_weights()
302302
delattr(self, "input_scales")
303303

@@ -528,7 +528,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
528528

529529
if self.act_scale_name is not None and self.act_scale_name in weights:
530530
input_scale = weights[self.act_scale_name].to(torch.float)
531-
self.input_scale = input_scale.cuda()
531+
self.input_scale = input_scale.cuda(get_current_device_id())
532532

533533
if weight is None and weight_scale is None and input_scale is None:
534534
return

0 commit comments

Comments
 (0)