Skip to content

Commit 35221d1

Browse files
authored
fix
1 parent 07ec38f commit 35221d1

File tree

2 files changed

+10
-18
lines changed

2 files changed

+10
-18
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
2-
ARG PYTORCH_VERSION=2.5.1
2+
ARG PYTORCH_VERSION=2.6.0
33
ARG PYTHON_VERSION=3.9
44
ARG CUDA_VERSION=12.4
55
ARG MAMBA_VERSION=23.1.0-1

lightllm/models/registry.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
logger = init_logger(__name__)
55

66
from dataclasses import dataclass
7-
from typing import Type, Dict, Optional, Callable, List
7+
from typing import Type, Dict, Optional, Callable, List, Union
88
from lightllm.utils.log_utils import init_logger
99

1010
logger = init_logger(__name__)
@@ -23,7 +23,7 @@ def __init__(self):
2323

2424
def __call__(
2525
self,
26-
model_type: str,
26+
model_type: Union[str, List[str]],
2727
is_multimodal: bool = False,
2828
condition: Optional[Callable[[dict], bool]] = None,
2929
):
@@ -58,6 +58,7 @@ def get_model(self, model_cfg: dict, model_kvargs: dict) -> tuple:
5858
if len(matches) > 1:
5959
# Keep conditionally matched models
6060
matches = [m for m in matches if m.condition is not None]
61+
6162
assert (
6263
len(matches) == 1
6364
), "Existence of coupled conditon, inability to determine the class of models instantiated"
@@ -78,23 +79,14 @@ def get_model(model_cfg: dict, model_kvargs: dict):
7879
raise
7980

8081

81-
def has_visual_config(cfg: dict) -> bool:
82-
return "visual" in cfg
83-
84-
8582
def is_reward_model() -> Callable[[Dict[str, any]], bool]:
86-
return lambda c: "RewardModel" in c.get("architectures", [])
87-
83+
return lambda model_cfg : "RewardModel" in model_cfg.get("architectures", [""])[0]
8884

89-
def architecture_is(name: str) -> Callable[[Dict[str, any]], bool]:
90-
"""Predicate: matches first element of model_cfg['architectures'] == name."""
91-
return lambda c: c.get("architectures", [""])[0] == name
9285

93-
94-
def llm_model_type_is(name: str) -> Callable[[Dict[str, any]], bool]:
95-
names = [name] if isinstance(name, str) else name
86+
def llm_model_type_is(name: Union[str, List[str]]) -> Callable[[Dict[str, any]], bool]:
9687
"""Predicate: matches model_cfg.get("llm_config").get("model_type") == name."""
97-
return lambda c: (
98-
c.get("llm_config", {}).get("model_type", "") in names
99-
or c.get("text_config", {}).get("model_type", "") in names
88+
names = [name] if isinstance(name, str) else name
89+
return lambda model_cfg : (
90+
model_cfg.get("llm_config", {}).get("model_type", "") in names
91+
or model_cfg.get("text_config", {}).get("model_type", "") in names
10092
)

0 commit comments

Comments
 (0)