Skip to content

Commit 5cf8ae1

Browse files
committed
Improve model path and name retrieval
1 parent f258c1d commit 5cf8ae1

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

scripts/huggingface/upload_model.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,33 @@
1818
CONFIG_NAME = "model_upload"
1919

2020

21-
def link_model_card(model_name: str, target_file: pathlib.Path):
21+
def link_model_card(model_path: pathlib.Path, target_file: pathlib.Path):
2222
"""Link the README associated to the model to the current directory."""
23-
model_directory = (
24-
pathlib.Path(__file__) / ".." / ".." / "the_well" / "benchmark" / "models"
25-
)
26-
readme_file = model_directory / model_name / "README.md"
23+
readme_file = model_path / "README.md"
2724
readme_file = readme_file.resolve()
25+
logger.info(f"Link {target_file=} to {readme_file=}")
2826
target_file.symlink_to(readme_file)
2927

3028

29+
def retrieve_model_name(cfg_target: str) -> str:
30+
"""Retrieve the name of the model folder from the hydra config target"""
31+
model_name = str(cfg_target.split(".")[-2])
32+
return model_name
33+
34+
35+
def get_model_path(model_name: str) -> pathlib.Path:
36+
return (
37+
pathlib.Path(__file__)
38+
/ ".."
39+
/ ".."
40+
/ ".."
41+
/ "the_well"
42+
/ "benchmark"
43+
/ "models"
44+
/ model_name
45+
).resolve()
46+
47+
3148
def upload_folder(folder: pathlib.Path, repo_id: str):
3249
api = HfApi()
3350
api.upload_large_folder(
@@ -62,14 +79,15 @@ def main(cfg: DictConfig):
6279
model_state_dict = checkpoint["model_state_dict"]
6380
model.load_state_dict(model_state_dict)
6481

65-
model_name = model.__class__.__name__
82+
model_name = retrieve_model_name(cfg.model._target_)
83+
model_path = get_model_path(model_name)
6684
dataset_name = str(cfg.data.well_dataset_name)
6785
repo_id = f"polymathic-ai/{model_name}-{dataset_name}"
6886
logger.info("Uploading model.")
6987
with tempfile.TemporaryDirectory() as tmp_dirname:
7088
tmp_dirname = pathlib.Path(tmp_dirname)
7189
# Copy model readme
72-
link_model_card(model_name, tmp_dirname / "README.md")
90+
link_model_card(model_path, tmp_dirname / "README.md")
7391
# Save model locally with HF formalism
7492
model.save_pretrained(tmp_dirname)
7593
upload_folder(tmp_dirname, repo_id=repo_id)

0 commit comments

Comments
 (0)