
Commit f29cb1c

[4/N] Tiny enable UP ruleset in Ruff (#994)
1 parent 1ede369 commit f29cb1c
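
This commit enables Ruff's UP (pyupgrade) rule set for the repository; the configuration change itself (presumably adding "UP" to the lint `select`/`extend-select` list) sits among the 69 changed files, and the diffs shown below are mostly the resulting autofixes: `typing.Dict`/`List`/`Optional` give way to the builtin generics of PEP 585 and the `X | None` unions of PEP 604, abstract container types move from `typing` to `collections.abc`, quoted annotations are unquoted where `from __future__ import annotations` makes that safe, and redundant arguments such as the explicit `"r"` mode of `open()` are dropped. The sketch below is illustrative only (hypothetical function, not from the commit; it assumes Python 3.10+):

```python
# Illustrative before/after shape of the UP-style rewrites in this commit.
from collections.abc import Mapping  # was: from typing import Mapping
from typing import Any


# Before: def summarize(config: Optional[Mapping[str, Any]], tags: List[str]) -> Dict[str, int]
def summarize(config: Mapping[str, Any] | None, tags: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}  # was: Dict[str, int]
    if config is None:
        config = {}
    for tag in tags:
        counts[tag] = counts.get(tag, 0) + 1
    return counts


print(summarize(None, ["a", "b", "a"]))  # {'a': 2, 'b': 1}
```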

69 files changed: 385 additions & 414 deletions


examples/eval/eval_delegate.py

Lines changed: 25 additions & 26 deletions
```diff
@@ -1,6 +1,7 @@
 import logging
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, field, fields
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
+from typing import Any, Optional
 
 from omegaconf import OmegaConf
 
@@ -14,7 +15,7 @@ def _first_not_none(*values: Any) -> Any:
     return None
 
 
-def _pick_from_mapping(data: Optional[Mapping[str, Any]], keys: Iterable[str]) -> Any:
+def _pick_from_mapping(data: Mapping[str, Any] | None, keys: Iterable[str]) -> Any:
     if not data:
         return None
     for key in keys:
@@ -28,11 +29,11 @@ class EvalEnvDatasetConfig:
     """Dataset-level generation parameters shared across delegate clients."""
 
     name: str = ""
-    n_samples_per_eval_prompt: Optional[int] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    max_response_len: Optional[int] = None
+    n_samples_per_eval_prompt: int | None = None
+    temperature: float | None = None
+    top_p: float | None = None
+    top_k: int | None = None
+    max_response_len: int | None = None
 
     # TODO: This is ugly, temporarily leave this. We should unify all the config name for dataset, default, and args. (advice from Tom.)
     FIELD_SPECS = {
@@ -75,7 +76,7 @@ def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]
                 "Colon in dataset name is not allowed; use `n_samples_per_eval_prompt` to configure samples per prompt."
             )
 
-        values: Dict[str, Any] = {"name": name}
+        values: dict[str, Any] = {"name": name}
         for field_name, spec in cls.FIELD_SPECS.items():
             dataset_value = _pick_from_mapping(dataset_cfg, spec["dataset_keys"])
             default_value = _pick_from_mapping(defaults, spec["default_keys"])
@@ -88,9 +89,9 @@ def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]
         obj = cls(**obj)
         return obj
 
-    def to_payload(self) -> Dict[str, Any]:
+    def to_payload(self) -> dict[str, Any]:
         """Return a JSON-serializable payload for this dataset configuration."""
-        payload: Dict[str, Any] = {}
+        payload: dict[str, Any] = {}
         for field_info in fields(self):
             value = getattr(self, field_info.name)
             if value is None:
@@ -104,11 +105,11 @@ class EvalEnvConfig:
     """Environment definition shared across delegate implementations."""
 
     name: str = ""
-    url: Optional[str] = None
+    url: str | None = None
     timeout_secs: int = 3600
     max_retries: int = 1
-    headers: Dict[str, Any] = field(default_factory=dict)
-    defaults: Dict[str, Any] = field(default_factory=dict)
+    headers: dict[str, Any] = field(default_factory=dict)
+    defaults: dict[str, Any] = field(default_factory=dict)
 
     @classmethod
     def parse(cls, raw: Mapping[str, Any], defaults: Mapping[str, Any]) -> "EvalEnvConfig":
@@ -121,9 +122,9 @@ def parse(cls, raw: Mapping[str, Any], defaults: Mapping[str, Any]) -> "EvalEnvC
 
 
 def _rebuild_delegate_config(
-    args, raw_delegate_config: Optional[Sequence[Mapping[str, Any]]], defaults: Optional[Mapping[str, Any]]
-) -> List[EvalEnvConfig]:
-    envs: List[EvalEnvConfig] = []
+    args, raw_delegate_config: Sequence[Mapping[str, Any]] | None, defaults: Mapping[str, Any] | None
+) -> list[EvalEnvConfig]:
+    envs: list[EvalEnvConfig] = []
     defaults = defaults or {}
     for env in raw_delegate_config or []:
         env_name = str(env.get("name", "")).strip().lower()
@@ -151,13 +152,13 @@ class EvalClient:
     def __init__(self, name: str):
         self.name = name
 
-    def evaluate(self, args, rollout_id: int) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
         raise NotImplementedError("Subclasses must implement this method")
 
 
-def _flatten(result: Dict[str, Any], prefix: Optional[str] = None) -> Dict[str, Any]:
+def _flatten(result: dict[str, Any], prefix: str | None = None) -> dict[str, Any]:
     """Flatten nested metric dicts into slash separated keys."""
-    flattened: Dict[str, Any] = {}
+    flattened: dict[str, Any] = {}
     for key, value in (result or {}).items():
         full_key = f"{prefix}/{key}" if prefix else key
         if isinstance(value, dict):
@@ -174,15 +175,13 @@ def __init__(self, delegates: Sequence[EvalClient]):
         self._delegates = list(delegates)
 
     @classmethod
-    def maybe_create(
-        cls, args, env_configs: Optional[Sequence[EvalEnvConfig]] = None
-    ) -> Optional["EvalDelegateClient"]:
+    def maybe_create(cls, args, env_configs: Sequence[EvalEnvConfig] | None = None) -> Optional["EvalDelegateClient"]:
         env_configs = list(env_configs) if env_configs is not None else getattr(args, "eval_delegate_config", None)
         if not env_configs:
             return None
 
         router_addr = f"http://{args.sglang_router_ip}:{args.sglang_router_port}"
-        delegates: List[EvalClient] = []
+        delegates: list[EvalClient] = []
         for env_cfg in env_configs:
             delegate = cls._create_delegate(env_cfg, router_addr)
             if delegate is not None:
@@ -201,9 +200,9 @@ def _create_delegate(env_cfg: EvalEnvConfig, router_addr: str):
            logger.warning("No delegate client registered for environment: %s", env_name)
        return None
 
-    def evaluate(self, args, rollout_id: int) -> tuple[Dict[str, Any], Dict[str, Any]]:
-        aggregated_metrics: Dict[str, Any] = {}
-        raw_responses: Dict[str, Any] = {}
+    def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
+        aggregated_metrics: dict[str, Any] = {}
+        raw_responses: dict[str, Any] = {}
         for delegate in self._delegates:
            metrics, response = delegate.evaluate(args, rollout_id)
            if metrics:
```
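
One detail worth noting in this file: the `maybe_create` signature is reflowed onto a single line, but its return annotation stays `Optional["EvalDelegateClient"]` even though every other `Optional` becomes `X | None`. The likely reason is that this module does not use `from __future__ import annotations`, so annotations are evaluated at import time, and a quoted forward reference cannot be an operand of `|` (a `str` instance does not support the operator). A minimal sketch of the failure mode, with a hypothetical class:

```python
from typing import Optional


class Widget:
    # Fine at runtime: the quoted name is only a subscript argument of Optional.
    def clone(self) -> Optional["Widget"]:
        return Widget()

    # Would raise at class-definition time without `from __future__ import annotations`:
    #     def clone_or_none(self) -> "Widget" | None: ...
    # TypeError: unsupported operand type(s) for |: 'str' and 'NoneType'


print(type(Widget().clone()).__name__)  # Widget
```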

examples/eval/eval_delegate_rollout.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -3,7 +3,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any
 
 from examples.eval.eval_delegate import EvalDelegateClient, _rebuild_delegate_config
 from omegaconf import OmegaConf
@@ -13,7 +13,7 @@
 
 logger = logging.getLogger(__name__)
 
-_DELEGATE_CACHE: dict[str, tuple[Optional[float], Optional[EvalDelegateClient]]] = {}
+_DELEGATE_CACHE: dict[str, tuple[float | None, EvalDelegateClient | None]] = {}
 
 
 def generate_rollout(
@@ -32,7 +32,7 @@ def generate_rollout(
     return result
 
 
-def _get_delegate_client(args) -> Optional[EvalDelegateClient]:
+def _get_delegate_client(args) -> EvalDelegateClient | None:
     config_path = getattr(args, "eval_config", None)
     if not config_path:
         return None
@@ -48,7 +48,7 @@ def _get_delegate_client(args) -> Optional[EvalDelegateClient]:
     return client
 
 
-def _build_delegate_client(args, config_path: str) -> Optional[EvalDelegateClient]:
+def _build_delegate_client(args, config_path: str) -> EvalDelegateClient | None:
     cfg = OmegaConf.load(config_path)
     cfg_dict = OmegaConf.to_container(cfg, resolve=True)
     if not isinstance(cfg_dict, dict):
@@ -70,22 +70,22 @@ def _build_delegate_client(args, config_path: str) -> Optional[EvalDelegateClien
     return EvalDelegateClient.maybe_create(args, env_configs=env_configs)
 
 
-def _safe_mtime(path: str) -> Optional[float]:
+def _safe_mtime(path: str) -> float | None:
     try:
         return os.path.getmtime(path)
     except OSError:
         return None
 
 
-def _log_delegate_metrics(args, rollout_id: int, metrics: dict | None, raw_response: Optional[dict]) -> dict:
+def _log_delegate_metrics(args, rollout_id: int, metrics: dict | None, raw_response: dict | None) -> dict:
     flattened = _flatten_metrics(metrics)
     if raw_response is not None:
         logger.info("External eval raw response for rollout %s: %s", rollout_id, raw_response)
     logger.info("eval %s (external): %s", rollout_id, flattened)
     return flattened
 
 
-def _flatten_metrics(metric_source: Optional[dict]) -> dict:
+def _flatten_metrics(metric_source: dict | None) -> dict:
     flattened_metrics: dict[str, float] = {}
     if not isinstance(metric_source, dict):
         return flattened_metrics
```

examples/eval/nemo_skills/skills_client.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, Dict, Optional
+from typing import Any
 
 import requests
 from examples.eval.eval_delegate import EvalClient, EvalDelegateError
@@ -28,7 +28,7 @@ def from_config(cls, config: SkillsEvalEnvConfig, router_url: str):
             return None
         return cls(config, router_url)
 
-    def evaluate(self, args, rollout_id: int) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]:
         if not self._config.datasets:
             logger.warning("No Skills datasets configured; skipping delegate evaluation.")
             return {}, {}
@@ -38,7 +38,7 @@ def evaluate(self, args, rollout_id: int) -> tuple[Dict[str, Any], Dict[str, Any
         metrics = response["raw_metrics"]
         return metrics, response
 
-    def _build_payload(self, args, rollout_id: int) -> Dict[str, Any]:
+    def _build_payload(self, args, rollout_id: int) -> dict[str, Any]:
         benchmarks = [cfg.to_payload() for cfg in self._config.datasets]
         benchmarks = [cfg for cfg in benchmarks if cfg]
         return {
@@ -47,8 +47,8 @@ def _build_payload(self, args, rollout_id: int) -> Dict[str, Any]:
             "benchmarks": benchmarks,
         }
 
-    def _request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
-        last_error: Optional[Exception] = None
+    def _request(self, payload: dict[str, Any]) -> dict[str, Any]:
+        last_error: Exception | None = None
         for attempt in range(1, self._max_retries + 1):
             try:
                 response = self._session.post(
```

examples/eval/nemo_skills/skills_config.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Any, List, Mapping
+from typing import Any
 
 from examples.eval.eval_delegate import EvalEnvConfig, EvalEnvDatasetConfig
 
@@ -35,10 +36,10 @@ def parse(cls, args, dataset_cfg: Mapping[str, Any], defaults: Mapping[str, Any]
 class SkillsEvalEnvConfig(EvalEnvConfig):
     """Environment configuration shared by the Skills client/server."""
 
-    datasets: List[SkillsEvalEnvDatasetConfig] = field(default_factory=list)
+    datasets: list[SkillsEvalEnvDatasetConfig] = field(default_factory=list)
 
     @classmethod
-    def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> "SkillsEvalEnvConfig":
+    def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> SkillsEvalEnvConfig:
         base_cfg: SkillsEvalEnvConfig = super().parse(raw_env_config, defaults)
         datasets = raw_env_config.get("datasets") or []
         base_cfg.datasets = [
```
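
By contrast, this module starts with `from __future__ import annotations` (PEP 563), so annotations are never evaluated at runtime and the quotes around the self-referencing return type can be dropped ("SkillsEvalEnvConfig" becomes SkillsEvalEnvConfig). A small hypothetical sketch of why the future import makes that safe:

```python
from __future__ import annotations  # annotations stay as strings, resolved lazily

from dataclasses import dataclass


@dataclass
class Node:
    value: int = 0

    @classmethod
    def default(cls) -> Node:  # no quotes needed, even though Node is still being defined
        return cls()


print(Node.default())  # Node(value=0)
```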

examples/eval/nemo_skills/skills_server.py

Lines changed: 17 additions & 16 deletions
```diff
@@ -30,9 +30,10 @@
 import threading
 import time
 import uuid
+from collections.abc import Mapping
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Mapping
+from typing import Any
 
 REPO_ROOT = Path(__file__).resolve().parents[3]
 if str(REPO_ROOT) not in sys.path:
@@ -56,8 +57,8 @@
 class EvalRequestPayload:
     rollout_id: int
     router_url: str
-    defaults: Dict[str, Any] = field(default_factory=dict)
-    benchmarks: List[SkillsEvalEnvDatasetConfig] = field(default_factory=list)
+    defaults: dict[str, Any] = field(default_factory=dict)
+    benchmarks: list[SkillsEvalEnvDatasetConfig] = field(default_factory=list)
 
 
 # ---------------------------------------------------------------------------
@@ -83,8 +84,8 @@ def _hydra_overrides_from_benchmark(
     router_url: str,
     openai_model_name: str,
     max_concurrent_requests: int,
-) -> List[str]:
-    overrides: List[str] = []
+) -> list[str]:
+    overrides: list[str] = []
     for key, hydra_key in HYDRA_OVERRIDE_MAP.items():
         value = getattr(benchmark_cfg, key, None)
         if value is None:
@@ -114,7 +115,7 @@ class ServerConfig:
     max_concurrent_requests: int = 512
 
     @classmethod
-    def from_args(cls, args: argparse.Namespace) -> "ServerConfig":
+    def from_args(cls, args: argparse.Namespace) -> ServerConfig:
         return cls(
             output_root=Path(args.output_root).expanduser().resolve(),
             cluster=args.cluster,
@@ -130,7 +131,7 @@ def __init__(self, config: ServerConfig):
         self._lock = threading.Lock()
         self._config.output_root.mkdir(parents=True, exist_ok=True)
 
-    def evaluate(self, payload: EvalRequestPayload) -> Dict[str, Any]:
+    def evaluate(self, payload: EvalRequestPayload) -> dict[str, Any]:
         if not payload.benchmarks:
             warning_msg = "No benchmarks specified in delegate config; skipping NeMo Skills evaluation."
             logger.warning(warning_msg)
@@ -149,8 +150,8 @@ def evaluate(self, payload: EvalRequestPayload) -> Dict[str, Any]:
         run_dir = self._config.output_root / f"{int(time.time())}-{exp_name}"
         run_dir.mkdir(parents=True, exist_ok=True)
 
-        runs: List[Dict[str, Any]] = []
-        raw_metrics: Dict[str, Any] = {}
+        runs: list[dict[str, Any]] = []
+        raw_metrics: dict[str, Any] = {}
         with self._lock:
             for benchmark in payload.benchmarks:
                 result = self._run_single_benchmark(
@@ -182,7 +183,7 @@ def _run_single_benchmark(
         exp_name: str,
         router_url: str,
         run_dir: Path,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         name = benchmark.name
         benchmark_run_dir = run_dir / name
         benchmark_run_dir.mkdir(parents=True, exist_ok=True)
@@ -220,7 +221,7 @@ def _build_command(
         run_dir: Path,
         defaults: Mapping[str, Any],
         benchmark_cfg: SkillsEvalEnvDatasetConfig,
-    ) -> List[str]:
+    ) -> list[str]:
         base_cmd = [
             "ns",
             "eval",
@@ -250,29 +251,29 @@ def _build_command(
         )
         return base_cmd + hydra_overrides
 
-    def _build_env(self) -> Dict[str, str]:
+    def _build_env(self) -> dict[str, str]:
         env = os.environ.copy()
         return env
 
     @staticmethod
-    def _run_command(cmd: List[str], *, env: Dict[str, str], log_path: Path):
+    def _run_command(cmd: list[str], *, env: dict[str, str], log_path: Path):
         with open(log_path, "w", encoding="utf-8") as log_file:
             process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, env=env)
         retcode = process.wait()
         if retcode != 0:
-            with open(log_path, "r", encoding="utf-8", errors="ignore") as log_file:
+            with open(log_path, encoding="utf-8", errors="ignore") as log_file:
                 tail = "".join(log_file.readlines()[-200:])
             raise RuntimeError(f"`ns eval` failed with exit code {retcode}. See {log_path}\n{tail}")
 
     @staticmethod
-    def _collect_metrics(run_dir: Path, benchmark: str) -> Dict[str, Any]:
+    def _collect_metrics(run_dir: Path, benchmark: str) -> dict[str, Any]:
         benchmark_name = benchmark.split(":")[0]
         metrics_path = run_dir / "eval-results" / benchmark_name / "metrics.json"
         if not metrics_path.exists():
             logger.warning("Metrics file missing for %s at %s", benchmark_name, metrics_path)
             return {}
         try:
-            with open(metrics_path, "r", encoding="utf-8") as fp:
+            with open(metrics_path, encoding="utf-8") as fp:
                 metrics_data = json.load(fp)
         except json.JSONDecodeError as exc:
             logger.warning("Failed to parse %s: %s", metrics_path, exc)
```

examples/formal_math/single_round/kimina_wrapper.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -2,7 +2,6 @@
 import os
 import random
 import time
-from typing import List
 
 import ray
 import requests
@@ -25,7 +24,7 @@ async def check(self, *args, **kwargs) -> CheckResponse:
 
 
 class _KiminaClientCluster:
-    def __init__(self, servers: List["_KiminaServerActor"]):
+    def __init__(self, servers: list["_KiminaServerActor"]):
         self._clients = [AsyncKiminaClient(api_url=ray.get(server.get_api_url.remote())) for server in servers]
         self._next_client_index = 0
 
@@ -35,7 +34,7 @@ async def check(self, *args, **kwargs):
         return await client.check(*args, **kwargs)
 
 
-def _create_actor_per_node(actor_cls) -> List:
+def _create_actor_per_node(actor_cls) -> list:
     # for simplicity, we use all available nodes
     nodes = [n for n in ray.nodes() if n.get("Alive")]
     assert len(nodes) > 0
```
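
Note that the quoted forward reference survives inside `list["_KiminaServerActor"]`: subscripting a builtin generic does not evaluate its argument, so the string is simply stored in the resulting `types.GenericAlias`, unlike the `"X" | None` case discussed earlier. A one-line illustration with a hypothetical name:

```python
alias = list["HypotheticalActor"]  # the string is never looked up
print(alias)                       # list['HypotheticalActor']
print(type(alias).__name__)        # GenericAlias
```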
