11import logging
2+ from collections .abc import Iterable , Mapping , Sequence
23from dataclasses import dataclass , field , fields
3- from typing import Any , Dict , Iterable , List , Mapping , Optional , Sequence
4+ from typing import Any , Optional
45
56from omegaconf import OmegaConf
67
@@ -14,7 +15,7 @@ def _first_not_none(*values: Any) -> Any:
1415 return None
1516
1617
17- def _pick_from_mapping (data : Optional [ Mapping [str , Any ]] , keys : Iterable [str ]) -> Any :
18+ def _pick_from_mapping (data : Mapping [str , Any ] | None , keys : Iterable [str ]) -> Any :
1819 if not data :
1920 return None
2021 for key in keys :
@@ -28,11 +29,11 @@ class EvalEnvDatasetConfig:
2829 """Dataset-level generation parameters shared across delegate clients."""
2930
3031 name : str = ""
31- n_samples_per_eval_prompt : Optional [ int ] = None
32- temperature : Optional [ float ] = None
33- top_p : Optional [ float ] = None
34- top_k : Optional [ int ] = None
35- max_response_len : Optional [ int ] = None
32+ n_samples_per_eval_prompt : int | None = None
33+ temperature : float | None = None
34+ top_p : float | None = None
35+ top_k : int | None = None
36+ max_response_len : int | None = None
3637
3738 # TODO: This is ugly, temporarily leave this. We should unify all the config name for dataset, default, and args. (advice from Tom.)
3839 FIELD_SPECS = {
@@ -75,7 +76,7 @@ def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]
7576 "Colon in dataset name is not allowed; use `n_samples_per_eval_prompt` to configure samples per prompt."
7677 )
7778
78- values : Dict [str , Any ] = {"name" : name }
79+ values : dict [str , Any ] = {"name" : name }
7980 for field_name , spec in cls .FIELD_SPECS .items ():
8081 dataset_value = _pick_from_mapping (dataset_cfg , spec ["dataset_keys" ])
8182 default_value = _pick_from_mapping (defaults , spec ["default_keys" ])
@@ -88,9 +89,9 @@ def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]
8889 obj = cls (** obj )
8990 return obj
9091
91- def to_payload (self ) -> Dict [str , Any ]:
92+ def to_payload (self ) -> dict [str , Any ]:
9293 """Return a JSON-serializable payload for this dataset configuration."""
93- payload : Dict [str , Any ] = {}
94+ payload : dict [str , Any ] = {}
9495 for field_info in fields (self ):
9596 value = getattr (self , field_info .name )
9697 if value is None :
@@ -104,11 +105,11 @@ class EvalEnvConfig:
104105 """Environment definition shared across delegate implementations."""
105106
106107 name : str = ""
107- url : Optional [ str ] = None
108+ url : str | None = None
108109 timeout_secs : int = 3600
109110 max_retries : int = 1
110- headers : Dict [str , Any ] = field (default_factory = dict )
111- defaults : Dict [str , Any ] = field (default_factory = dict )
111+ headers : dict [str , Any ] = field (default_factory = dict )
112+ defaults : dict [str , Any ] = field (default_factory = dict )
112113
113114 @classmethod
114115 def parse (cls , raw : Mapping [str , Any ], defaults : Mapping [str , Any ]) -> "EvalEnvConfig" :
@@ -121,9 +122,9 @@ def parse(cls, raw: Mapping[str, Any], defaults: Mapping[str, Any]) -> "EvalEnvC
121122
122123
123124def _rebuild_delegate_config (
124- args , raw_delegate_config : Optional [ Sequence [Mapping [str , Any ]]] , defaults : Optional [ Mapping [str , Any ]]
125- ) -> List [EvalEnvConfig ]:
126- envs : List [EvalEnvConfig ] = []
125+ args , raw_delegate_config : Sequence [Mapping [str , Any ]] | None , defaults : Mapping [str , Any ] | None
126+ ) -> list [EvalEnvConfig ]:
127+ envs : list [EvalEnvConfig ] = []
127128 defaults = defaults or {}
128129 for env in raw_delegate_config or []:
129130 env_name = str (env .get ("name" , "" )).strip ().lower ()
@@ -151,13 +152,13 @@ class EvalClient:
151152 def __init__ (self , name : str ):
152153 self .name = name
153154
154- def evaluate (self , args , rollout_id : int ) -> tuple [Dict [str , Any ], Dict [str , Any ]]:
155+ def evaluate (self , args , rollout_id : int ) -> tuple [dict [str , Any ], dict [str , Any ]]:
155156 raise NotImplementedError ("Subclasses must implement this method" )
156157
157158
158- def _flatten (result : Dict [str , Any ], prefix : Optional [ str ] = None ) -> Dict [str , Any ]:
159+ def _flatten (result : dict [str , Any ], prefix : str | None = None ) -> dict [str , Any ]:
159160 """Flatten nested metric dicts into slash separated keys."""
160- flattened : Dict [str , Any ] = {}
161+ flattened : dict [str , Any ] = {}
161162 for key , value in (result or {}).items ():
162163 full_key = f"{ prefix } /{ key } " if prefix else key
163164 if isinstance (value , dict ):
@@ -174,15 +175,13 @@ def __init__(self, delegates: Sequence[EvalClient]):
174175 self ._delegates = list (delegates )
175176
176177 @classmethod
177- def maybe_create (
178- cls , args , env_configs : Optional [Sequence [EvalEnvConfig ]] = None
179- ) -> Optional ["EvalDelegateClient" ]:
178+ def maybe_create (cls , args , env_configs : Sequence [EvalEnvConfig ] | None = None ) -> Optional ["EvalDelegateClient" ]:
180179 env_configs = list (env_configs ) if env_configs is not None else getattr (args , "eval_delegate_config" , None )
181180 if not env_configs :
182181 return None
183182
184183 router_addr = f"http://{ args .sglang_router_ip } :{ args .sglang_router_port } "
185- delegates : List [EvalClient ] = []
184+ delegates : list [EvalClient ] = []
186185 for env_cfg in env_configs :
187186 delegate = cls ._create_delegate (env_cfg , router_addr )
188187 if delegate is not None :
@@ -201,9 +200,9 @@ def _create_delegate(env_cfg: EvalEnvConfig, router_addr: str):
201200 logger .warning ("No delegate client registered for environment: %s" , env_name )
202201 return None
203202
204- def evaluate (self , args , rollout_id : int ) -> tuple [Dict [str , Any ], Dict [str , Any ]]:
205- aggregated_metrics : Dict [str , Any ] = {}
206- raw_responses : Dict [str , Any ] = {}
203+ def evaluate (self , args , rollout_id : int ) -> tuple [dict [str , Any ], dict [str , Any ]]:
204+ aggregated_metrics : dict [str , Any ] = {}
205+ raw_responses : dict [str , Any ] = {}
207206 for delegate in self ._delegates :
208207 metrics , response = delegate .evaluate (args , rollout_id )
209208 if metrics :
0 commit comments