Skip to content

Commit 1d68171

Browse files
author
Guang Yang
committed
Fix script export_hf_model.py
1 parent 708c6b6 commit 1d68171

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

extension/export_util/export_hf_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,11 @@ def main() -> None:
7373
print(f"{model.config}")
7474
print(f"{model.generation_config}")
7575

76-
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
77-
input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
76+
input_ids = torch.tensor([[1]], dtype=torch.long)
7877
cache_position = torch.tensor([0], dtype=torch.long)
7978

8079
def _get_constant_methods(model: PreTrainedModel):
81-
return {
80+
metadata = {
8281
"get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
8382
"get_bos_id": model.config.bos_token_id,
8483
"get_eos_id": model.config.eos_token_id,
@@ -90,6 +89,7 @@ def _get_constant_methods(model: PreTrainedModel):
9089
"get_vocab_size": model.config.vocab_size,
9190
"use_kv_cache": model.generation_config.use_cache,
9291
}
92+
return {k: v for k, v in metadata.items() if v is not None}
9393

9494
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
9595

0 commit comments

Comments
 (0)