-
Notifications
You must be signed in to change notification settings - Fork 12
Feature: Support downloading model weights on-the-fly from HuggingFace (#166) #167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
fc843ed
5f790ff
0f22bec
9f2fdd2
38011be
4de3563
eb1e929
8b6a211
c68cb35
b610891
a7a5deb
bb3142b
2db9c6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,8 +37,18 @@ def __init__(self, params: dict[str, Any]): | |
| self.additional_binds = ( | ||
| f",{self.params['bind']}" if self.params.get("bind") else "" | ||
| ) | ||
| self.model_weights_path = str( | ||
| Path(self.params["model_weights_parent_dir"], self.params["model_name"]) | ||
| model_weights_path = Path( | ||
| self.params["model_weights_parent_dir"], self.params["model_name"] | ||
| ) | ||
| self.model_weights_exists = model_weights_path.exists() | ||
| self.model_weights_path = str(model_weights_path) | ||
| self.model_source = ( | ||
| self.model_weights_path | ||
| if self.model_weights_exists | ||
| else self.params["model_name"] | ||
| ) | ||
| self.model_bind_option = ( | ||
| f",{self.model_weights_path}" if self.model_weights_exists else "" | ||
| ) | ||
| self.env_str = self._generate_env_str() | ||
|
|
||
|
|
@@ -111,7 +121,9 @@ def _generate_server_setup(self) -> str: | |
| server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"])) | ||
| server_script.append( | ||
| SLURM_SCRIPT_TEMPLATE["bind_path"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| model_weights_path=self.model_weights_path | ||
| if self.model_weights_exists | ||
| else "", | ||
| additional_binds=self.additional_binds, | ||
| ) | ||
| ) | ||
|
|
@@ -131,7 +143,6 @@ def _generate_server_setup(self) -> str: | |
| server_setup_str = server_setup_str.replace( | ||
| "CONTAINER_PLACEHOLDER", | ||
| SLURM_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| env_str=self.env_str, | ||
| ), | ||
| ) | ||
|
|
@@ -165,22 +176,27 @@ def _generate_launch_cmd(self) -> str: | |
| Server launch command. | ||
| """ | ||
| launcher_script = ["\n"] | ||
|
|
||
| vllm_args_copy = self.params["vllm_args"].copy() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if this is necessary, as the model name should be parsed with launch command not part of --vllm-args
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, I'm open to alternative approaches if you have a preference, like:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah thanks for the clarification, I think having a dedicated CLI option keeps things clean and means minimal changes. However the code base has changed pretty significantly and there are quite a few conflicts in order to merge the changes, if you give me the access to your branch I can help resolve the conflicts. |
||
| model_source = self.model_source | ||
| if "--model" in vllm_args_copy: | ||
| model_source = vllm_args_copy.pop("--model") | ||
|
|
||
| if self.use_container: | ||
| launcher_script.append( | ||
| SLURM_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=self.model_weights_path, | ||
| env_str=self.env_str, | ||
| ) | ||
| ) | ||
|
|
||
| launcher_script.append( | ||
| "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format( | ||
| model_weights_path=self.model_weights_path, | ||
| model_source=model_source, | ||
| model_name=self.params["model_name"], | ||
| ) | ||
| ) | ||
|
|
||
| for arg, value in self.params["vllm_args"].items(): | ||
| for arg, value in vllm_args_copy.items(): | ||
| if isinstance(value, bool): | ||
| launcher_script.append(f" {arg} \\") | ||
| else: | ||
|
|
@@ -225,11 +241,20 @@ def __init__(self, params: dict[str, Any]): | |
| if self.params["models"][model_name].get("bind") | ||
| else "" | ||
| ) | ||
| self.params["models"][model_name]["model_weights_path"] = str( | ||
| Path( | ||
| self.params["models"][model_name]["model_weights_parent_dir"], | ||
| model_name, | ||
| ) | ||
| model_weights_path = Path( | ||
| self.params["models"][model_name]["model_weights_parent_dir"], | ||
| model_name, | ||
| ) | ||
| model_weights_exists = model_weights_path.exists() | ||
| model_weights_path_str = str(model_weights_path) | ||
| self.params["models"][model_name]["model_weights_path"] = ( | ||
| model_weights_path_str | ||
| ) | ||
| self.params["models"][model_name]["model_weights_exists"] = ( | ||
| model_weights_exists | ||
| ) | ||
| self.params["models"][model_name]["model_source"] = ( | ||
| model_weights_path_str if model_weights_exists else model_name | ||
| ) | ||
|
|
||
| def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path: | ||
|
|
@@ -266,7 +291,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path: | |
| script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"]) | ||
| script_content.append( | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| model_weights_path=model_params["model_weights_path"] | ||
| if model_params.get("model_weights_exists", True) | ||
| else "", | ||
| additional_binds=model_params["additional_binds"], | ||
| ) | ||
| ) | ||
|
|
@@ -283,19 +310,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path: | |
| model_name=model_name, | ||
| ) | ||
| ) | ||
| vllm_args_copy = model_params["vllm_args"].copy() | ||
| model_source = model_params.get( | ||
| "model_source", model_params["model_weights_path"] | ||
| ) | ||
| if "--model" in vllm_args_copy: | ||
| model_source = vllm_args_copy.pop("--model") | ||
|
|
||
| if self.use_container: | ||
| script_content.append( | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| ) | ||
| BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format() | ||
| ) | ||
| script_content.append( | ||
| "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format( | ||
| model_weights_path=model_params["model_weights_path"], | ||
| model_source=model_source, | ||
| model_name=model_name, | ||
| ) | ||
| ) | ||
| for arg, value in model_params["vllm_args"].items(): | ||
|
|
||
| for arg, value in vllm_args_copy.items(): | ||
| if isinstance(value, bool): | ||
| script_content.append(f" {arg} \\") | ||
| else: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.