44logger = init_logger (__name__ )
55
66from dataclasses import dataclass
7- from typing import Type , Dict , Optional , Callable , List
7+ from typing import Type , Dict , Optional , Callable , List , Union
88from lightllm .utils .log_utils import init_logger
99
1010logger = init_logger (__name__ )
@@ -23,7 +23,7 @@ def __init__(self):
2323
2424 def __call__ (
2525 self ,
26- model_type : str ,
26+ model_type : Union [ str , List [ str ]] ,
2727 is_multimodal : bool = False ,
2828 condition : Optional [Callable [[dict ], bool ]] = None ,
2929 ):
@@ -58,6 +58,7 @@ def get_model(self, model_cfg: dict, model_kvargs: dict) -> tuple:
5858 if len (matches ) > 1 :
5959 # Keep conditionally matched models
6060 matches = [m for m in matches if m .condition is not None ]
61+
6162 assert (
6263 len (matches ) == 1
6364 ), "Existence of coupled conditon, inability to determine the class of models instantiated"
@@ -78,23 +79,14 @@ def get_model(model_cfg: dict, model_kvargs: dict):
7879 raise
7980
8081
81- def has_visual_config (cfg : dict ) -> bool :
82- return "visual" in cfg
83-
84-
8582def is_reward_model () -> Callable [[Dict [str , any ]], bool ]:
86- return lambda c : "RewardModel" in c .get ("architectures" , [])
87-
83+ return lambda model_cfg : "RewardModel" in model_cfg .get ("architectures" , ["" ])[0 ]
8884
89- def architecture_is (name : str ) -> Callable [[Dict [str , any ]], bool ]:
90- """Predicate: matches first element of model_cfg['architectures'] == name."""
91- return lambda c : c .get ("architectures" , ["" ])[0 ] == name
9285
93-
94- def llm_model_type_is (name : str ) -> Callable [[Dict [str , any ]], bool ]:
95- names = [name ] if isinstance (name , str ) else name
86+ def llm_model_type_is (name : Union [str , List [str ]]) -> Callable [[Dict [str , any ]], bool ]:
9687 """Predicate: matches model_cfg.get("llm_config").get("model_type") == name."""
97- return lambda c : (
98- c .get ("llm_config" , {}).get ("model_type" , "" ) in names
99- or c .get ("text_config" , {}).get ("model_type" , "" ) in names
88+ names = [name ] if isinstance (name , str ) else name
89+ return lambda model_cfg : (
90+ model_cfg .get ("llm_config" , {}).get ("model_type" , "" ) in names
91+ or model_cfg .get ("text_config" , {}).get ("model_type" , "" ) in names
10092 )
0 commit comments