|
35 | 35 | from huggingface_hub import snapshot_download # noqa |
36 | 36 | from datasets import load_dataset # noqa |
37 | 37 | from app import __version__ # noqa |
| 38 | +from app.config import Settings # noqa |
38 | 39 | from app.domain import ModelType, TrainingType, BuildBackend, Device, ArchiveFormat, LlmEngine # noqa |
39 | 40 | from app.registry import model_service_registry # noqa |
40 | 41 | from app.api.api import ( |
|
44 | 45 | get_vllm_server, |
45 | 46 | get_app_for_api_docs, |
46 | 47 | ) # noqa |
47 | | -from app.utils import get_settings, send_gelf_message, download_model_package # noqa |
| 48 | +from app.utils import get_settings, send_gelf_message, download_model_package, get_model_data_package_base_name # noqa |
48 | 49 | from app.management.model_manager import ModelManager # noqa |
49 | 50 | from app.api.dependencies import ModelServiceDep, ModelManagerDep # noqa |
50 | 51 | from app.management.tracker_client import TrackerClient # noqa |
@@ -113,10 +114,7 @@ def serve_model( |
113 | 114 | model_service_dep = ModelServiceDep(model_type, config, model_name) |
114 | 115 | cms_globals.model_service_dep = model_service_dep |
115 | 116 |
|
116 | | - dst_model_path = os.path.join(parent_dir, "model", "model.zip" if model_path.endswith(".zip") else "model.tar.gz") |
117 | | - config.BASE_MODEL_FILE = "model.zip" if model_path.endswith(".zip") else "model.tar.gz" |
118 | | - if dst_model_path and os.path.exists(os.path.splitext(dst_model_path)[0]): |
119 | | - shutil.rmtree(os.path.splitext(dst_model_path)[0]) |
| 117 | + dst_model_path = _ensure_dst_model_path(model_path, parent_dir, config) |
120 | 118 |
|
121 | 119 | if model_path: |
122 | 120 | if model_path.startswith("http://") or model_path.startswith("https://"): |
@@ -221,15 +219,13 @@ def train_model( |
221 | 219 | model_service_dep = ModelServiceDep(model_type, config) |
222 | 220 | cms_globals.model_service_dep = model_service_dep |
223 | 221 |
|
224 | | - dst_model_path = os.path.join(parent_dir, "model", "model.zip" if base_model_path.endswith(".zip") else "model.tar.gz") |
225 | | - config.BASE_MODEL_FILE = "model.zip" if base_model_path.endswith(".zip") else "model.tar.gz" |
226 | | - if dst_model_path and os.path.exists(os.path.splitext(dst_model_path)[0]): |
227 | | - shutil.rmtree(os.path.splitext(dst_model_path)[0]) |
| 222 | + dst_model_path = _ensure_dst_model_path(base_model_path, parent_dir, config) |
228 | 223 |
|
229 | 224 | if base_model_path: |
230 | 225 | try: |
231 | 226 | shutil.copy2(base_model_path, dst_model_path) |
232 | 227 | except shutil.SameFileError: |
| 228 | + logger.warning("Source and destination are the same model package file.") |
233 | 229 | pass |
234 | 230 | model_service = model_service_dep() |
235 | 231 | model_service.model_name = model_name if model_name is not None else "CMS model" |
@@ -708,6 +704,23 @@ def show_banner() -> None: |
708 | 704 | typer.echo(banner) |
709 | 705 |
|
710 | 706 |
|
| 707 | +def _ensure_dst_model_path(model_path: str, parent_dir: str, config: Settings) -> str: |
| 708 | + if model_path.endswith(".zip"): |
| 709 | + dst_model_path = os.path.join(parent_dir, "model", "model.zip") |
| 710 | + config.BASE_MODEL_FILE = "model.zip" |
| 711 | + else: |
| 712 | + dst_model_path = os.path.join(parent_dir, "model", "model.tar.gz") |
| 713 | + config.BASE_MODEL_FILE = "model.tar.gz" |
| 714 | + model_dir = os.path.join(parent_dir, "model", "model") |
| 715 | + if os.path.exists(model_dir): |
| 716 | + shutil.rmtree(model_dir) |
| 717 | + if dst_model_path.endswith(".zip") and os.path.exists(dst_model_path.replace(".zip", ".tar.gz")): |
| 718 | + os.remove(dst_model_path.replace(".zip", ".tar.gz")) |
| 719 | + if dst_model_path.endswith(".tar.gz") and os.path.exists(dst_model_path.replace(".tar.gz", ".zip")): |
| 720 | + os.remove(dst_model_path.replace(".tar.gz", ".zip")) |
| 721 | + return dst_model_path |
| 722 | + |
| 723 | + |
711 | 724 | def _get_logger( |
712 | 725 | debug: Optional[bool] = None, |
713 | 726 | model_type: Optional[ModelType] = None, |
|
0 commit comments