Skip to content

Commit 3ffadda

Browse files
author
Guang Yang
committed
Fix script export_hf_model.py
1 parent 708c6b6 commit 3ffadda

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

extension/export_util/export_hf_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,11 @@ def main() -> None:
7474
print(f"{model.generation_config}")
7575

7676
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
77-
input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
77+
input_ids = torch.tensor([[1]], dtype=torch.long)
7878
cache_position = torch.tensor([0], dtype=torch.long)
7979

8080
def _get_constant_methods(model: PreTrainedModel):
81-
return {
81+
metadata = {
8282
"get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
8383
"get_bos_id": model.config.bos_token_id,
8484
"get_eos_id": model.config.eos_token_id,
@@ -90,6 +90,7 @@ def _get_constant_methods(model: PreTrainedModel):
9090
"get_vocab_size": model.config.vocab_size,
9191
"use_kv_cache": model.generation_config.use_cache,
9292
}
93+
return {k: v for k, v in metadata.items() if v is not None}
9394

9495
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
9596

0 commit comments

Comments
 (0)