
Commit 0c5536e

Guang Yang authored and facebook-github-bot committed
Fix script export_hf_model.py (#6246)
Summary: Having `bos_token_id = None` (and other fields) will cause an emitter error for models that doesn't define that field, for example, [olmo-1b](https://huggingface.co/allenai/OLMo-1B-hf/blob/main/config.json#L8). ``` raise ExportError( executorch.exir.error.ExportError: [ExportErrorType.NOT_SUPPORTED]: Error emitting get_bos_id which returns a value of type <class 'NoneType'>. which is not a supported primitive ``` This PR avoids emitting with unsupported primitive type by removing the `None` fields out from the model metadata. The ExecuTorch runtime will assume the default value for those unspecified fields. Pull Request resolved: #6246 Reviewed By: kirklandsign Differential Revision: D64438406 Pulled By: guangy10 fbshipit-source-id: dceb7a08c1231d7dadf237476b9a229580a96b94
1 parent dc4be7c

File tree

1 file changed: +4 -4 lines


extension/export_util/export_hf_model.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -12,7 +12,7 @@
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
 from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel
@@ -73,12 +73,11 @@ def main() -> None:
     print(f"{model.config}")
     print(f"{model.generation_config}")
 
-    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
-    input_ids = tokenizer([""], return_tensors="pt").to(device)["input_ids"]
+    input_ids = torch.tensor([[1]], dtype=torch.long)
     cache_position = torch.tensor([0], dtype=torch.long)
 
     def _get_constant_methods(model: PreTrainedModel):
-        return {
+        metadata = {
             "get_dtype": 5 if model.config.torch_dtype == torch.float16 else 6,
             "get_bos_id": model.config.bos_token_id,
             "get_eos_id": model.config.eos_token_id,
@@ -90,6 +89,7 @@ def _get_constant_methods(model: PreTrainedModel):
             "get_vocab_size": model.config.vocab_size,
             "use_kv_cache": model.generation_config.use_cache,
         }
+        return {k: v for k, v in metadata.items() if v is not None}
 
     with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
```
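For readers skimming the diff, here is a minimal, self-contained sketch of the filtering idea in isolation. `DummyConfig`, `build_metadata`, and the field values are hypothetical stand-ins, not the real Hugging Face classes or the values the script reads at runtime:

```python
# Minimal sketch of the None-filtering approach. DummyConfig is a
# hypothetical stand-in for a Hugging Face PretrainedConfig; it is
# not part of the actual script.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DummyConfig:
    bos_token_id: Optional[int] = None  # unset, as in olmo-1b's config.json
    eos_token_id: Optional[int] = 50279
    vocab_size: int = 50304


def build_metadata(config: DummyConfig) -> dict:
    metadata = {
        "get_bos_id": config.bos_token_id,
        "get_eos_id": config.eos_token_id,
        "get_vocab_size": config.vocab_size,
    }
    # Drop None-valued entries so the emitter never sees a NoneType;
    # the runtime falls back to its defaults for any missing key.
    return {k: v for k, v in metadata.items() if v is not None}


print(build_metadata(DummyConfig()))
# -> {'get_eos_id': 50279, 'get_vocab_size': 50304}
```

Omitting the keys entirely, rather than serializing a placeholder, is what lets the ExecuTorch runtime apply its own defaults for unspecified fields.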
