Skip to content

Commit e76cf9d

Browse files
authored
fix: https://nvbugs/5234033 enable starcoder trt-flow with transformers 4.51.3 (#3909)
fix: https://nvbugs/5234033 enable starcoder trt-flow with transformers 4.51.3. Signed-off-by: nv-guomingz <[email protected]>
1 parent 5dc3b53 commit e76cf9d

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tensorrt_llm/models/gpt/convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,9 @@ def load_weights_from_hf_model(hf_model,
459459
f'{prefix}.self_attn.k_proj', dtype)
460460
v_w, v_b = get_weight_and_bias(model_params,
461461
f'{prefix}.self_attn.v_proj', dtype)
462-
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
463-
qkv_b = torch.cat([q_b, k_b, v_b],
464-
dim=0) if q_b is not None else None
462+
qkv_w = torch.cat([q_w.cuda(), k_w.cuda(), v_w.cuda()], dim=0)
463+
qkv_b = torch.cat([q_b.cuda(), k_b.cuda(),
464+
v_b.cuda()], dim=0) if q_b is not None else None
465465
elif gpt_variant == 'persimmon':
466466
qkv_w, qkv_b = get_weight_and_bias(
467467
model_params, f'{prefix}.self_attn.query_key_value', dtype)

tensorrt_llm/parameter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,9 @@ def _get_weights(self, network) -> trt.Weights | Tensor | None:
265265
def _regularize_value(self, value):
266266
if isinstance(value, np.ndarray):
267267
return value
268+
269+
elif isinstance(value, torch.distributed.tensor.DTensor):
270+
return value.to_local().cpu().numpy()
268271
elif isinstance(value, torch.Tensor):
269272
return torch_to_numpy(value)
270273
raise TypeError(

0 commit comments

Comments
 (0)