Skip to content

Commit f8f9fd6

Browse files
committed
Move model types to environment config to be dynamic
1 parent 12d3c55 commit f8f9fd6

File tree

5 files changed

+57
-35
lines changed

5 files changed

+57
-35
lines changed

vec_inf/cli/_vars.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,40 @@
44
used in the CLI display formatting.
55
"""
66

7+
from typing import get_args
8+
9+
from vec_inf.client._slurm_vars import MODEL_TYPES
10+
11+
12+
# Extract model type values from the Literal type
13+
_MODEL_TYPES = get_args(MODEL_TYPES)
14+
15+
# Rich color options (prioritizing current colors, with fallbacks for additional types)
16+
_RICH_COLORS = [
17+
"cyan",
18+
"bright_blue",
19+
"purple",
20+
"bright_magenta",
21+
"green",
22+
"yellow",
23+
"bright_green",
24+
"bright_yellow",
25+
"red",
26+
"bright_red",
27+
"blue",
28+
"magenta",
29+
"bright_cyan",
30+
"white",
31+
"bright_white",
32+
]
33+
734
# Mapping of model types to their display priority (lower numbers shown first)
8-
MODEL_TYPE_PRIORITY = {
9-
"LLM": 0,
10-
"VLM": 1,
11-
"Text_Embedding": 2,
12-
"Reward_Modeling": 3,
13-
}
35+
MODEL_TYPE_PRIORITY = {model_type: idx for idx, model_type in enumerate(_MODEL_TYPES)}
1436

1537
# Mapping of model types to their display colors in Rich
1638
MODEL_TYPE_COLORS = {
17-
"LLM": "cyan",
18-
"VLM": "bright_blue",
19-
"Text_Embedding": "purple",
20-
"Reward_Modeling": "bright_magenta",
39+
model_type: _RICH_COLORS[idx % len(_RICH_COLORS)]
40+
for idx, model_type in enumerate(_MODEL_TYPES)
2141
}
2242

2343
# Inference engine choice and name mapping

vec_inf/client/_slurm_vars.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ def create_literal_type(values: list[str], fallback: str = "") -> Any:
8282
_config["allowed_values"]["resource_type"]
8383
)
8484

85-
# Extract required arguments, for launching jobs that don't have a default value and
86-
# their corresponding environment variables
85+
# Model types available derived from the cached model config
86+
MODEL_TYPES: TypeAlias = create_literal_type(_config["model_types"]) # type: ignore[valid-type]
87+
88+
# Required arguments for launching jobs and corresponding environment variables
8789
REQUIRED_ARGS: dict[str, str | None] = _config["required_args"]
8890

89-
# Extract python version, running sglang requires python version
91+
# Running sglang requires python version
9092
PYTHON_VERSION: str = _config["python_version"]
9193

9294
# Extract default arguments

vec_inf/client/config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from typing import Any, Optional, Union
99

1010
from pydantic import BaseModel, ConfigDict, Field
11-
from typing_extensions import Literal
1211

1312
from vec_inf.client._slurm_vars import (
1413
DEFAULT_ARGS,
1514
MAX_CPUS_PER_TASK,
1615
MAX_GPUS_PER_NODE,
1716
MAX_NUM_NODES,
17+
MODEL_TYPES,
1818
PARTITION,
1919
QOS,
2020
RESOURCE_TYPE,
@@ -88,9 +88,7 @@ class ModelConfig(BaseModel):
8888
model_variant: Optional[str] = Field(
8989
default=None, description="Specific variant/version of the model family"
9090
)
91-
model_type: Literal["LLM", "VLM", "Text_Embedding", "Reward_Modeling"] = Field(
92-
..., description="Type of model architecture"
93-
)
91+
model_type: MODEL_TYPES = Field(..., description="Type of model architecture")
9492
gpus_per_node: int = Field(
9593
..., gt=0, le=MAX_GPUS_PER_NODE, description="GPUs per node"
9694
)

vec_inf/client/models.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626
from dataclasses import dataclass, field
2727
from enum import Enum
28-
from typing import Any, Optional, Union
28+
from typing import Any, Optional, Union, get_args
29+
30+
from vec_inf.client._slurm_vars import MODEL_TYPES
2931

3032

3133
class ModelStatus(str, Enum):
@@ -55,25 +57,23 @@ class ModelStatus(str, Enum):
5557
UNAVAILABLE = "UNAVAILABLE"
5658

5759

58-
class ModelType(str, Enum):
59-
"""Enum representing the possible model types.
60+
# Extract model type values from the Literal type
61+
_MODEL_TYPE_VALUES = get_args(MODEL_TYPES)
62+
63+
64+
def _model_type_to_enum_name(model_type: str) -> str:
65+
"""Convert a model type string to a valid enum attribute name."""
66+
# Convert to uppercase and replace hyphens with underscores
67+
return model_type.upper().replace("-", "_")
6068

61-
Attributes
62-
----------
63-
LLM : str
64-
Large Language Model
65-
VLM : str
66-
Vision Language Model
67-
TEXT_EMBEDDING : str
68-
Text Embedding Model
69-
REWARD_MODELING : str
70-
Reward Modeling Model
71-
"""
7269

73-
LLM = "LLM"
74-
VLM = "VLM"
75-
TEXT_EMBEDDING = "Text_Embedding"
76-
REWARD_MODELING = "Reward_Modeling"
70+
# Create ModelType enum dynamically from MODEL_TYPES
71+
ModelType = Enum(
72+
"ModelType",
73+
{_model_type_to_enum_name(mt): mt for mt in _MODEL_TYPE_VALUES},
74+
type=str,
75+
module=__name__,
76+
)
7777

7878

7979
@dataclass

vec_inf/config/environment.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ required_args:
2424

2525
python_version: "python3.12"
2626

27+
model_types: ["LLM", "VLM", "Text_Embedding", "Reward_Modeling", "OCR"] # Derived from models.yaml
28+
2729
default_args:
2830
cpus_per_task: "16"
2931
mem_per_node: "64G"

0 commit comments

Comments
 (0)