|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from clarifai.runners.models.model_builder import ModelBuilder |
8 | | -from clarifai.runners.utils.const import DEFAULT_RUNTIME_DOWNLOAD_PATH |
9 | 8 | from clarifai.runners.utils.loader import HuggingFaceLoader |
10 | 9 |
|
11 | 10 | MODEL_ID = "timm/mobilenetv3_small_100.lamb_in1k" |
@@ -62,16 +61,20 @@ def test_download_checkpoints(dummy_runner_models_dir): |
62 | 61 | model_builder = ModelBuilder(model_folder_path, download_validation_only=True) |
63 | 62 | # defaults to runtime stage which matches config.yaml not having a when field. |
64 | 63 | # get whatever stage is in config.yaml to force download now |
| 64 | + # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses |
65 | 65 | _, _, _, when = model_builder._validate_config_checkpoints() |
66 | | - checkpoint_dir = model_builder.download_checkpoints(stage=when) |
67 | | - assert checkpoint_dir == DEFAULT_RUNTIME_DOWNLOAD_PATH |
| 66 | + checkpoint_dir = model_builder.download_checkpoints( |
| 67 | + stage=when, checkpoint_path_override=model_builder.checkpoint_path) |
| 68 | + assert checkpoint_dir == model_builder.checkpoint_path |
68 | 69 |
|
69 | 70 | # This doesn't have when in it's config.yaml so build. |
70 | 71 | model_folder_path = os.path.join(os.path.dirname(__file__), "hf_mbart_model") |
71 | 72 | model_builder = ModelBuilder(model_folder_path, download_validation_only=True) |
72 | 73 | # defaults to runtime stage which matches config.yaml not having a when field. |
73 | 74 | # get whatever stage is in config.yaml to force download now |
| 75 | + # also always write to where upload/build wants to, not the /tmp folder that runtime stage uses |
74 | 76 | _, _, _, when = model_builder._validate_config_checkpoints() |
75 | | - checkpoint_dir = model_builder.download_checkpoints(stage=when) |
| 77 | + checkpoint_dir = model_builder.download_checkpoints( |
| 78 | + stage=when, checkpoint_path_override=model_builder.checkpoint_path) |
76 | 79 | assert checkpoint_dir == os.path.join( |
77 | 80 | os.path.dirname(__file__), "hf_mbart_model", "1", "checkpoints") |
0 commit comments