|
18 | 18 | CONFIG_NAME = "model_upload" |
19 | 19 |
|
20 | 20 |
|
21 | | -def link_model_card(model_name: str, target_file: pathlib.Path): |
| 21 | +def link_model_card(model_path: pathlib.Path, target_file: pathlib.Path): |
22 | 22 | """Link the README associated to the model to the current directory.""" |
23 | | - model_directory = ( |
24 | | - pathlib.Path(__file__) / ".." / ".." / "the_well" / "benchmark" / "models" |
25 | | - ) |
26 | | - readme_file = model_directory / model_name / "README.md" |
| 23 | + readme_file = model_path / "README.md" |
27 | 24 | readme_file = readme_file.resolve() |
| 25 | + logger.info(f"Link {target_file=} to {readme_file=}") |
28 | 26 | target_file.symlink_to(readme_file) |
29 | 27 |
|
30 | 28 |
|
| 29 | +def retrieve_model_name(cfg_target: str) -> str: |
| 30 | + """Retrieve the name of the model folder from the hydra config target""" |
| 31 | + model_name = str(cfg_target.split(".")[-2]) |
| 32 | + return model_name |
| 33 | + |
| 34 | + |
| 35 | +def get_model_path(model_name: str) -> pathlib.Path: |
| 36 | + return ( |
| 37 | + pathlib.Path(__file__) |
| 38 | + / ".." |
| 39 | + / ".." |
| 40 | + / ".." |
| 41 | + / "the_well" |
| 42 | + / "benchmark" |
| 43 | + / "models" |
| 44 | + / model_name |
| 45 | + ).resolve() |
| 46 | + |
| 47 | + |
31 | 48 | def upload_folder(folder: pathlib.Path, repo_id: str): |
32 | 49 | api = HfApi() |
33 | 50 | api.upload_large_folder( |
@@ -62,14 +79,15 @@ def main(cfg: DictConfig): |
62 | 79 | model_state_dict = checkpoint["model_state_dict"] |
63 | 80 | model.load_state_dict(model_state_dict) |
64 | 81 |
|
65 | | - model_name = model.__class__.__name__ |
| 82 | + model_name = retrieve_model_name(cfg.model._target_) |
| 83 | + model_path = get_model_path(model_name) |
66 | 84 | dataset_name = str(cfg.data.well_dataset_name) |
67 | 85 | repo_id = f"polymathic-ai/{model_name}-{dataset_name}" |
68 | 86 | logger.info("Uploading model.") |
69 | 87 | with tempfile.TemporaryDirectory() as tmp_dirname: |
70 | 88 | tmp_dirname = pathlib.Path(tmp_dirname) |
71 | 89 | # Copy model readme |
72 | | - link_model_card(model_name, tmp_dirname / "README.md") |
| 90 | + link_model_card(model_path, tmp_dirname / "README.md") |
73 | 91 | # Save model locally with HF formalism |
74 | 92 | model.save_pretrained(tmp_dirname) |
75 | 93 | upload_folder(tmp_dirname, repo_id=repo_id) |
|
0 commit comments