diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..c64d5c3f3 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Generic, List, ParamSpec, TypeVar +from typing import Generic, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty @@ -96,7 +96,7 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: sess_id = kwargs.pop("sess_id", None) return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) - async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> list[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" result = await self.service.call_all(self.endpoint_name, *args, **kwargs) return result @@ -107,7 +107,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: "Services only support route() and fanout()." ) - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]: raise NotImplementedError( "You tried to use call() on a service, not an actor. " "Services only support route() and fanout()." @@ -119,7 +119,7 @@ async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: "Services only support route() and fanout()." ) - async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> list[R]: raise NotImplementedError( "You tried to use broadcast() on a service, not an actor. " "Services only support route() and fanout()." 
@@ -157,7 +157,7 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: sess_id, self.endpoint_name, *args, **kwargs ) - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + async def call(self, *args: P.args, **kwargs: P.kwargs) -> list[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" result = await self.actor_mesh.call_all.call_one( self.endpoint_name, *args, **kwargs @@ -314,9 +314,9 @@ class Router(ABC): @abstractmethod def get_replica( self, - healthy_replicas: List[Replica], + healthy_replicas: list[Replica], sess_id: str | None = None, - session_map: Dict[str, int] | None = None, + session_map: dict[str, int] | None = None, ) -> Replica: """Select a replica from the list based on routing logic.""" pass diff --git a/src/forge/controller/service/metrics.py b/src/forge/controller/service/metrics.py index d328728bd..728d7d57a 100644 --- a/src/forge/controller/service/metrics.py +++ b/src/forge/controller/service/metrics.py @@ -12,7 +12,6 @@ """ from dataclasses import dataclass, field -from typing import Dict, List from forge.controller.service.replica import ReplicaMetrics @@ -35,7 +34,7 @@ class ServiceMetrics: """ # Replica metrics - replica_metrics: Dict[int, ReplicaMetrics] = field(default_factory=dict) + replica_metrics: dict[int, ReplicaMetrics] = field(default_factory=dict) # Service-level metrics total_sessions: int = 0 healthy_replicas: int = 0 @@ -50,7 +49,7 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float: for metrics in self.replica_metrics.values() ) - def get_avg_queue_depth(self, replicas: List) -> float: + def get_avg_queue_depth(self, replicas: list) -> float: """Get average queue depth across all healthy replicas.""" healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: @@ -58,7 +57,7 @@ def get_avg_queue_depth(self, replicas: List) -> float: total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas) return 
total_queue_depth / len(healthy_replicas) - def get_avg_capacity_utilization(self, replicas: List) -> float: + def get_avg_capacity_utilization(self, replicas: list) -> float: """Get average capacity utilization across all healthy replicas.""" healthy_replicas = [r for r in replicas if r.healthy] if not healthy_replicas: diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 502402e36..a53d9c873 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List from .interface import Router from .replica import Replica @@ -22,9 +21,9 @@ def __init__(self): def get_replica( self, - healthy_replicas: List[Replica], + healthy_replicas: list[Replica], sess_id: str | None = None, - session_map: Dict[str, int] | None = None, + session_map: dict[str, int] | None = None, ) -> Replica: if not healthy_replicas: raise RuntimeError("No healthy replicas available for load balancing") @@ -40,9 +39,9 @@ class LeastLoadedRouter(Router): def get_replica( self, - healthy_replicas: List[Replica], + healthy_replicas: list[Replica], sess_id: str | None = None, - session_map: Dict[str, int] | None = None, + session_map: dict[str, int] | None = None, ) -> Replica: if not healthy_replicas: raise RuntimeError("No healthy replicas available for session assignment") @@ -57,9 +56,9 @@ def __init__(self, fallback_router: Router): def get_replica( self, - healthy_replicas: List[Replica], + healthy_replicas: list[Replica], sess_id: str | None = None, - session_map: Dict[str, int] | None = None, + session_map: dict[str, int] | None = None, ) -> Replica: if sess_id is None: raise ValueError("SessionRouter requires a session ID") diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 0b655fb6a..fe7184d6e 100644 --- a/src/forge/controller/service/service.py +++ 
b/src/forge/controller/service/service.py @@ -36,7 +36,6 @@ import logging import pprint import uuid -from typing import Dict, List from monarch.actor import Actor, endpoint @@ -92,7 +91,7 @@ def __init__( self._active_sessions = [] self._id_session_map = {} - self._session_replica_map: Dict[str, int] = {} + self._session_replica_map: dict[str, int] = {} # Initialize metrics collection self._metrics = ServiceMetrics() @@ -196,7 +195,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - async def call_all(self, function: str, *args, **kwargs) -> List: + async def call_all(self, function: str, *args, **kwargs) -> list: """ Broadcasts a function call to all healthy replicas and returns results as a list. @@ -622,7 +621,7 @@ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): self._active_sessions = [] self._id_session_map = {} - self._session_replica_map: Dict[str, int] = {} + self._session_replica_map: dict[str, int] = {} self._next_replica_idx = 0 # For round-robin load balancing # Initialize metrics collection @@ -726,7 +725,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): raise @endpoint - async def call_all(self, function: str, *args, **kwargs) -> List: + async def call_all(self, function: str, *args, **kwargs) -> list: """ Broadcasts a function call to all healthy replicas and returns results as a list. 
diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 9de5fe60b..e50cc3fd6 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -7,7 +7,7 @@ import asyncio import logging import os -from typing import Any, Dict, Union +from typing import Any, Union from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc @@ -127,7 +127,7 @@ def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> N @endpoint async def flush( self, global_step: int, return_state: bool = False - ) -> Dict[str, Dict[str, Any]]: + ) -> dict[str, dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. This should only ever be called by the global logger. @@ -136,7 +136,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + dict[str, dict[str, Any]]: Dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. 
""" collector = MetricCollector() @@ -146,8 +146,8 @@ async def flush( @endpoint async def init_backends( self, - metadata_per_primary_backend: Dict[str, Dict[str, Any]], - config: Dict[str, Any], + metadata_per_primary_backend: dict[str, dict[str, Any]], + config: dict[str, Any], ) -> None: """Init local (per-rank) logger backends and MetricCollector.""" collector = MetricCollector() @@ -179,13 +179,13 @@ class GlobalLoggingActor(Actor): """ def __init__(self): - self.fetchers: Dict[str, LocalFetcherActor] = {} - self.config: Dict[str, Any] | None = None - self.global_logger_backends: Dict[str, LoggerBackend] = {} - self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} + self.fetchers: dict[str, LocalFetcherActor] = {} + self.config: dict[str, Any] | None = None + self.global_logger_backends: dict[str, LoggerBackend] = {} + self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {} @endpoint - async def init_backends(self, config: Dict[str, Any]) -> None: + async def init_backends(self, config: dict[str, Any]) -> None: """ Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors in all registered fetchers. @@ -201,7 +201,7 @@ async def init_backends(self, config: Dict[str, Any]) -> None: and reduce them to a single value, which will be logged by the primary backend in this controller. Args: - config (Dict[str, Any]): Config for metric logging where keys are backend names, + config (dict[str, Any]): Config for metric logging where keys are backend names, e.g. 
{"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} """ self.config = config diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index b59e4b723..3ce849ad2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List +from typing import Any import pytz from monarch.actor import context, current_rank @@ -143,18 +143,18 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None collector.push(metric) -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: +def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metric]: """Reduce metric accumulators states to a list of metrics. Can be used when reducing metrics across ranks or services, as merging states is more precise than merging locally reduced metrics. Args: - states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics, + states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. 
Returns: - List[Metric]: List of reduced metrics + list[Metric]: List of reduced metrics Example: states = [ @@ -226,13 +226,13 @@ def get_value(self) -> Any: pass @abstractmethod - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: """Returns serializable state for cross-rank merge (e.g., {'sum': 10.0, 'count': 5}).""" pass @classmethod @abstractmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> Any: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> Any: """Merges states from multiple ranks into single reduced value (e.g., total_sum/total_count for MEAN).""" pass @@ -256,7 +256,7 @@ def append(self, value: Any) -> None: def get_value(self) -> float: return self.sum / self.count if self.count > 0 else 0.0 - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { "reduction_type": self.reduction_type.value, "sum": self.sum, @@ -264,7 +264,7 @@ def get_state(self) -> Dict[str, Any]: } @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: total_sum = sum(s["sum"] for s in states) total_count = sum(s["count"] for s in states) return total_sum / total_count if total_count > 0 else 0.0 @@ -286,11 +286,11 @@ def append(self, value: Any) -> None: def get_value(self) -> float: return self.total - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {"reduction_type": self.reduction_type.value, "total": self.total} @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return sum(s["total"] for s in states) def reset(self) -> None: @@ -309,11 +309,11 @@ def append(self, value: Any) -> None: def get_value(self) -> float: return self.max_val - def get_state(self) -> Dict[str, Any]: + def 
get_state(self) -> dict[str, Any]: return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return max(s["max_val"] for s in states) def reset(self) -> None: @@ -332,11 +332,11 @@ def append(self, value: Any) -> None: def get_value(self) -> float: return self.min_val - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: return min(s["min_val"] for s in states) def reset(self) -> None: @@ -365,7 +365,7 @@ def get_value(self) -> float: variance = (self.sum_sq / self.count) - (mean * mean) return max(0.0, variance) ** 0.5 - def get_state(self) -> Dict[str, Any]: + def get_state(self) -> dict[str, Any]: return { "reduction_type": self.reduction_type.value, "sum": self.sum, @@ -374,7 +374,7 @@ def get_state(self) -> Dict[str, Any]: } @classmethod - def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + def get_reduced_value_from_states(cls, states: list[dict[str, Any]]) -> float: total_sum = sum(s["sum"] for s in states) total_sum_sq = sum(s["sum_sq"] for s in states) total_count = sum(s["count"] for s in states) @@ -411,7 +411,7 @@ class MetricCollector: - Resets accumulators post-flush to avoid leaks across train steps; """ - _instances: Dict[int, "MetricCollector"] = {} + _instances: dict[int, "MetricCollector"] = {} _singleton_rank: int def __new__(cls): @@ -434,24 +434,24 @@ def __init__(self) -> None: if hasattr(self, "_is_initialized"): return - self.accumulators: Dict[str, MetricAccumulator] = {} + self.accumulators: dict[str, MetricAccumulator] = {} self.rank = 
current_rank().rank - self.logger_backends: List[LoggerBackend] = [] + self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False async def init_backends( self, - metadata_per_primary_backend: Dict[str, Dict[str, Any]] | None, - config: Dict[str, Any], + metadata_per_primary_backend: dict[str, dict[str, Any]] | None, + config: dict[str, Any], ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated once globally. Args: - metadata_per_primary_backend (Dict[str, Dict[str, Any]] | None): Metadata from primary + metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. - config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. """ if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") @@ -519,7 +519,7 @@ def push(self, metric: Metric) -> None: async def flush( self, global_step: int, return_state: bool = False - ) -> Dict[str, Dict[str, Any]]: + ) -> dict[str, dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: @@ -527,7 +527,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state}, + dict[str, dict[str, Any]]: Dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. 
""" if not self._is_initialized: @@ -585,14 +585,14 @@ async def shutdown(self): class LoggerBackend(ABC): """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" - def __init__(self, logger_backend_config: Dict[str, Any]) -> None: + def __init__(self, logger_backend_config: dict[str, Any]) -> None: self.logger_backend_config = logger_backend_config @abstractmethod async def init( self, role: BackendRole, - primary_logger_metadata: Dict[str, Any] | None = None, + primary_logger_metadata: dict[str, Any] | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -600,7 +600,7 @@ async def init( Args: role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. - primary_logger_metadata (Dict[str, Any] | None): From global backend for + primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. Raises: ValueError if missing metadata for shared local init. @@ -608,7 +608,7 @@ async def init( pass @abstractmethod - async def log(self, metrics: List[Metric], global_step: int) -> None: + async def log(self, metrics: list[Metric], global_step: int) -> None: """ Log a batch of metrics to the backend. @@ -621,7 +621,7 @@ async def log(self, metrics: List[Metric], global_step: int) -> None: async def finish(self) -> None: pass - def get_metadata_for_secondary_ranks(self) -> Dict[str, Any] | None: + def get_metadata_for_secondary_ranks(self) -> dict[str, Any] | None: """Return sharable state after primary init (e.g., for shared modes). 
Called only on globals.""" return None @@ -629,13 +629,13 @@ def get_metadata_for_secondary_ranks(self) -> Dict[str, Any] | None: class ConsoleBackend(LoggerBackend): """Simple console logging of metrics.""" - def __init__(self, logger_backend_config: Dict[str, Any]) -> None: + def __init__(self, logger_backend_config: dict[str, Any]) -> None: super().__init__(logger_backend_config) async def init( self, role: BackendRole, - primary_logger_metadata: Dict[str, Any] | None = None, + primary_logger_metadata: dict[str, Any] | None = None, ) -> None: self.prefix = ( get_actor_name_with_rank() @@ -643,7 +643,7 @@ async def init( else "Controller" ) - async def log(self, metrics: List[Metric], global_step: int) -> None: + async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( f" {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) @@ -674,7 +674,7 @@ class WandbBackend(LoggerBackend): group (str, optional): WandB group name for organizing runs. 
Defaults to "experiment_group" """ - def __init__(self, logger_backend_config: Dict[str, Any]) -> None: + def __init__(self, logger_backend_config: dict[str, Any]) -> None: super().__init__(logger_backend_config) self.project = logger_backend_config["project"] self.group = logger_backend_config.get("group", "experiment_group") @@ -688,7 +688,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: Dict[str, Any] | None = None, + primary_logger_metadata: dict[str, Any] | None = None, ) -> None: if primary_logger_metadata is None: @@ -737,7 +737,7 @@ async def _init_shared_global(self): ) self.run = wandb.init(project=self.project, group=self.group, settings=settings) - async def _init_shared_local(self, primary_metadata: Dict[str, Any]): + async def _init_shared_local(self, primary_metadata: dict[str, Any]): import wandb shared_id = primary_metadata.get("shared_run_id") @@ -762,7 +762,7 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): settings=settings, ) - async def log(self, metrics: List[Metric], global_step: int) -> None: + async def log(self, metrics: list[Metric], global_step: int) -> None: if self.run: # Convert metrics to WandB log format log_data = {"global_step": global_step} @@ -776,7 +776,7 @@ async def log(self, metrics: List[Metric], global_step: int) -> None: else: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") - def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]: + def get_metadata_for_secondary_ranks(self) -> dict[str, Any]: if self.run and not self.reduce_across_ranks and self.share_run_id: return {"shared_run_id": self.run.id} return {} diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index 19e8ccb47..d34a52a95 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -11,7 +11,7 @@ from concurrent.futures import 
Future, ThreadPoolExecutor from functools import lru_cache, wraps -from typing import List, Protocol, Tuple +from typing import Protocol import torch @@ -216,7 +216,7 @@ def start(self) -> None: def step(self, name: str) -> None: ... - def get_all_durations(self) -> List[Tuple[str, float]]: + def get_all_durations(self) -> list[tuple[str, float]]: ... @@ -226,7 +226,7 @@ class _TimerCPU(_TimerProtocol): """ def __init__(self) -> None: - self._durations: List[Tuple[str, float]] = [] + self._durations: list[tuple[str, float]] = [] self._chain_start: float | None = None def start(self) -> None: @@ -242,7 +242,7 @@ def step(self, name: str) -> None: self._durations.append((name, delta_ms)) self._chain_start = now - def get_all_durations(self) -> List[Tuple[str, float]]: + def get_all_durations(self) -> list[tuple[str, float]]: return self._durations[:] @@ -255,10 +255,10 @@ def __init__(self, max_workers: int = 2) -> None: if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available for timing") self._executor = ThreadPoolExecutor(max_workers=max_workers) - self._futures: List[ - Tuple[str, Future[float], int] + self._futures: list[ + tuple[str, Future[float], int] ] = [] # (name, future, submission_index) - self._durations: List[Tuple[str, float]] = [] + self._durations: list[tuple[str, float]] = [] self._chain_start: torch.cuda.Event | None = None def start(self) -> None: @@ -323,7 +323,7 @@ def _collect_completed_futures(self) -> None: self._futures = still_pending - def get_all_durations(self) -> List[Tuple[str, float]]: + def get_all_durations(self) -> list[tuple[str, float]]: """Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish.""" # Wait and collect if pendings; return durations. 
self._collect_completed_futures() diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 01d1603d1..81723f2dd 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -7,7 +7,7 @@ import asyncio import time from contextlib import contextmanager -from typing import List, Literal, Tuple, Union +from typing import Literal, Union from unittest.mock import Mock, patch import pytest @@ -21,7 +21,7 @@ @pytest.fixture def mock_record_metric_calls(monkeypatch): """Mock record_metric that tracks all calls.""" - calls: List[Tuple[str, float, Reduce]] = [] + calls: list[tuple[str, float, Reduce]] = [] def mock_record_metric(name, val, red): calls.append((name, val, red)) diff --git a/tests/unit_tests/rl/environments/test_chat.py b/tests/unit_tests/rl/environments/test_chat.py index 678e0a33f..4abf89dc6 100644 --- a/tests/unit_tests/rl/environments/test_chat.py +++ b/tests/unit_tests/rl/environments/test_chat.py @@ -10,7 +10,7 @@ # LICENSE file in the root directory of this source tree. 
import unittest -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest.mock import MagicMock import torch @@ -29,9 +29,9 @@ class MockTokenizer: def apply_chat_template( self, - conversation: List[Dict[str, str]], - tools: Optional[List[Dict]] = None, - documents: Optional[List[Dict[str, str]]] = None, + conversation: list[dict[str, str]], + tools: Optional[list[dict]] = None, + documents: Optional[list[dict[str, str]]] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = False, continue_final_message: bool = False, @@ -42,7 +42,7 @@ def apply_chat_template( return_tensors: Optional[str] = None, return_dict: bool = False, return_assistant_tokens_mask: bool = False, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[dict[str, Any]] = None, **kwargs, ) -> torch.Tensor: """Mock implementation of apply_chat_template.""" @@ -90,7 +90,7 @@ def test_init(self): def test_init_with_values(self): """Test initialization with provided values.""" - messages: List[Message] = [{"role": "user", "content": "Hello"}] + messages: list[Message] = [{"role": "user", "content": "Hello"}] tokens = [torch.tensor([1, 2, 3])] state = ChatState(history_messages=messages, history_tokens=tokens) self.assertEqual(state.history_messages, messages) @@ -111,7 +111,7 @@ def test_init(self): def test_init_with_values(self): """Test initialization with provided values.""" - messages: List[Message] = [{"role": "user", "content": "Hello"}] + messages: list[Message] = [{"role": "user", "content": "Hello"}] tokens = torch.tensor([1, 2, 3]) obs = ChatObservation( messages=messages,