diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 330bead57..80fced0dd 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -246,7 +246,7 @@ async def setup(self): self.request_id = 0 self.policy_version = 0 - self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {} + self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {} # TODO: Investigate whether this can be combined with `policy.running` # Whether this policy is accepting requests. diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5ca331f32..258849429 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -12,7 +12,6 @@ import os import socket import uuid -from typing import Optional from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host @@ -163,7 +162,7 @@ async def get_proc_mesh( num_procs: int, with_gpus: bool = False, num_hosts: int | None = None, - mesh_name: Optional[str] = None, + mesh_name: str | None = None, host_mesh: HostMesh | None = None, env_vars: dict[str, str] | None = None, addr: str | None = None, diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 9ab8ec20a..fa9a7791a 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -11,7 +11,6 @@ from collections import deque from dataclasses import dataclass, field from enum import Enum -from typing import Optional from monarch.actor import ActorError @@ -81,7 +80,7 @@ class ServiceRequest: """ - session_id: Optional[str] + session_id: str | None function: str args: tuple kwargs: dict @@ -107,7 +106,7 @@ class Replica: actor_kwargs: dict # The Actor that this replica is running - actor: Optional[ForgeActor] = None + actor: ForgeActor | None = None # Async queue for incoming requests request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue) @@ -127,10 +126,10 @@ class Replica: return_first_rank_result: bool = False # Recovery-related state - _recovery_task: Optional[asyncio.Task] = None + _recovery_task: asyncio.Task | None = None # Run task is the replica's event loop - _run_task: Optional[asyncio.Task] = None + _run_task: asyncio.Task | None = None # Metrics tracking metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics) diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 56b17712b..6be68b41b 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Callable, Iterator, Optional +from typing import Any, Callable, Iterator import torch import torch.distributed as dist @@ -37,26 +37,26 @@ class HfIterableDataset(InfiniteTuneIterableDataset): - Returning an infinite iterator over the dataset Args: - message_transform (Optional[Transform]): Transforms raw data into a `Message`. - model_transform (Optional[Transform]): Prepares messages for the model, + message_transform (Transform | None): Transforms raw data into a `Message`. + model_transform (Transform | None): Prepares messages for the model, usually by tokenizing them. - output_transform (Optional[Transform]): Prepares tokenized inputs for the + output_transform (Transform | None): Prepares tokenized inputs for the recipe, often by manipulating labels (e.g., setting an ignore index). This transform is recipe-dependent (e.g., SFT, DPO, etc.). - metric_transform (Optional[MetricTransform]): Computes metrics from a + metric_transform (MetricTransform | None): Computes metrics from a sample (e.g., token count). If ``None``, a default transform is used. To disable standard metric tracking, set this to ``lambda x: x``. - shuffle_buffer_size (Optional[int]): Size of the shuffle buffer. + shuffle_buffer_size (int | None): Size of the shuffle buffer. If ``None`` or 0, no shuffling is performed. - weight (Optional[float]): Weight for this dataset. Defaults to 1.0. + weight (float | None): Weight for this dataset. Defaults to 1.0. seed (int): Seed for shuffling. num_shards_per_rank (int): The target number of shards per worker (GPU). The actual number of shards will be a multiple of ``world_size * dataloader_workers``. - dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is + dataset_name (str | None): Name of the dataset. If ``None``, a name is generated from the ``path``, ``source``, and ``split``. - filter_fn (Optional[Callable]): A function to filter the dataset. - filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``. + filter_fn (Callable | None): A function to filter the dataset. + filter_kwargs (dict[str, Any] | None): Keyword arguments for ``filter_fn``. **load_dataset_kwargs: Keyword arguments for the :func:`~datasets.load_dataset` function. """ @@ -64,17 +64,17 @@ class HfIterableDataset(InfiniteTuneIterableDataset): def __init__( self, *, - message_transform: Optional[Transform] = None, - model_transform: Optional[Transform] = None, - output_transform: Optional[Transform] = None, - metric_transform: Optional[MetricTransform] = None, - shuffle_buffer_size: Optional[int] = 1000, - weight: Optional[float] = 1.0, + message_transform: Transform | None = None, + model_transform: Transform | None = None, + output_transform: Transform | None = None, + metric_transform: MetricTransform | None = None, + shuffle_buffer_size: int | None = 1000, + weight: float | None = 1.0, seed: int = 42, num_shards_per_rank: int = 64, - dataset_name: Optional[str] = None, - filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[dict[str, Any]] = None, + dataset_name: str | None = None, + filter_fn: Callable | None = None, + filter_kwargs: dict[str, Any] | None = None, **load_dataset_kwargs, ): # Store configuration @@ -135,8 +135,8 @@ def _setup_hf_dataset( self, load_dataset_kwargs: dict[str, Any], num_shards_per_rank: int, - filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[dict[str, Any]] = None, + filter_fn: Callable | None = None, + filter_kwargs: dict[str, Any] | None = None, ): """ One-time setup of HuggingFace dataset that handles Handles distributed sharding, diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 105921acb..d09c158c0 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -7,7 +7,7 @@ import logging from abc import ABC, abstractmethod from collections import deque -from typing import Any, Generic, Iterable, Iterator, Optional, TypeVar +from typing import Any, Generic, Iterable, Iterator, TypeVar import torch from torch.nn.attention.flex_attention import ( @@ -329,13 +329,13 @@ def _reset_packer_state(self) -> None: self._buffer.clear() # current_pack: the current pack being built - self._current_pack: Optional[dict[str, list]] = None + self._current_pack: dict[str, list] | None = None # current_pack_size: the number of tokens in the current pack self._current_pack_size: int = 0 # iterator: the iterator over the dataset - self._iterator: Optional[Iterator[SampleType]] = None + self._iterator: Iterator[SampleType] | None = None # current_doc_id_in_pack: the document ID to use for the next sample self._current_doc_id_in_pack: int = 0 @@ -367,7 +367,7 @@ def _fill_buffer(self, iterator: Iterator[SampleType]) -> None: except StopIteration: self._exhausted = True - def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]: + def _find_next_fitting_sample(self, remaining_size: int) -> int | None: """ Find the first sample in the buffer that fits in the remaining space. @@ -375,7 +375,7 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]: remaining_size (int): The remaining space in the current pack. Returns: - Optional[int]: The index of the sample in the buffer, or None if no sample fits. + int | None: The index of the sample in the buffer, or None if no sample fits. Example: self._buffer = deque([(sample1, 200), (sample2, 100), (sample3, 48), (sample4, 200)]) @@ -397,7 +397,7 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]: return i return None - def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[SampleDict]: + def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None: """ Builds a pack of samples from the buffer. @@ -405,7 +405,7 @@ def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[SampleDict iterator (Iterator[SampleType]): The iterator over the dataset. Returns: - Optional[SampleDict]: The pack of samples, or None if the dataset is exhausted. + SampleDict | None: The pack of samples, or None if the dataset is exhausted. """ # Start a new pack if necessary if self._current_pack is None: diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index e6f6edcfb..fca97f912 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Optional +from typing import Any, Callable import torch @@ -26,7 +26,7 @@ class AlpacaToMessages(Transform): due to this custom logic. Args: - column_map (Optional[dict[str, str]]): a mapping to change the expected "instruction", "input", + column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input", and "output" column names to the actual column names in the dataset. Default is None, keeping the default column names. masking_strategy (str): masking strategy to use for model training. @@ -45,7 +45,7 @@ class AlpacaToMessages(Transform): def __init__( self, - column_map: Optional[dict[str, str]] = None, + column_map: dict[str, str] | None = None, masking_strategy: str = "train_on_all", ): self.masking_strategy = masking_strategy @@ -158,12 +158,12 @@ def sft_iterable_dataset( *, weight: int = 1, message_transform: Transform, - shuffle_buffer_size: Optional[int] = 1000, + shuffle_buffer_size: int | None = 1000, seed: int = 42, num_shards_per_rank: int = 64, - dataset_name: Optional[str] = None, - filter_fn: Optional[Callable] = None, - filter_kwargs: Optional[dict[str, Any]] = None, + dataset_name: str | None = None, + filter_fn: Callable | None = None, + filter_kwargs: dict[str, Any] | None = None, **load_dataset_kwargs: dict[str, Any], ) -> HfIterableDataset: """ @@ -173,12 +173,12 @@ def sft_iterable_dataset( model_transform (Transform): Usually the tokenizer weight (int): Weight of the dataset. Used for sampling when interleaving datasets. message_transform (Transform): Transform to convert raw data to messages - shuffle_buffer_size (Optional[int]): Buffer size for shuffling + shuffle_buffer_size (int | None): Buffer size for shuffling seed (int): Random seed for shuffling num_shards_per_rank (int): Target shards per worker - dataset_name (Optional[str]): Name for metrics namespacing - filter_fn (Optional[Callable]): Filter function - filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs + dataset_name (str | None): Name for metrics namespacing + filter_fn (Callable | None): Filter function + filter_kwargs (dict[str, Any] | None): Filter function kwargs **load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset Returns: diff --git a/src/forge/data/tokenizer.py b/src/forge/data/tokenizer.py index 3cb90f79c..f3d244ff5 100644 --- a/src/forge/data/tokenizer.py +++ b/src/forge/data/tokenizer.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import json -from typing import Any, Optional +from typing import Any import jinja2 from jinja2 import StrictUndefined @@ -28,8 +28,8 @@ class HuggingFaceBaseTokenizer(BaseTokenizer): Args: tokenizer_json_path (str): Path to tokenizer.json file - tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None - generation_config_path (Optional[str]): Path to generation_config.json file. + tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None + generation_config_path (str | None): Path to generation_config.json file. Default: None Raises: @@ -40,8 +40,8 @@ def __init__( self, tokenizer_json_path: str, *, - tokenizer_config_json_path: Optional[str] = None, - generation_config_path: Optional[str] = None, + tokenizer_config_json_path: str | None = None, + generation_config_path: str | None = None, ): self.tokenizer = Tokenizer.from_file(tokenizer_json_path) if not (tokenizer_config_json_path or generation_config_path): @@ -209,8 +209,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer): Args: tokenizer_json_path (str): Path to tokenizer.json file - tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None - generation_config_path (Optional[str]): Path to generation_config.json file. + tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None + generation_config_path (str | None): Path to generation_config.json file. Default: None truncation_type (str): type of truncation to apply, either "left" or "right". Default is "right". @@ -220,8 +220,8 @@ def __init__( self, tokenizer_json_path: str, *, - tokenizer_config_json_path: Optional[str] = None, - generation_config_path: Optional[str] = None, + tokenizer_config_json_path: str | None = None, + generation_config_path: str | None = None, truncation_type: str = "right", ): self.base_tokenizer = HuggingFaceBaseTokenizer( @@ -274,7 +274,7 @@ def tokenize_messages( self, messages: list[Message], add_eos: bool = True, - max_seq_len: Optional[int] = None, + max_seq_len: int | None = None, ) -> tuple[list[int], list[bool]]: tokenized_messages = [] mask = [] diff --git a/src/forge/data/utils.py b/src/forge/data/utils.py index e87ff50c9..b2fdaec0c 100644 --- a/src/forge/data/utils.py +++ b/src/forge/data/utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from enum import Enum -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union import torch @@ -118,7 +118,7 @@ def __repr__(self) -> str: def truncate( tokens: list[Any], max_seq_len: int, - eos_id: Optional[Any] = None, + eos_id: Any | None = None, truncation_type: str = "right", ) -> list[Any]: """ @@ -128,7 +128,7 @@ def truncate( Args: tokens (list[Any]): list of tokens to truncate max_seq_len (int): maximum length of the list - eos_id (Optional[Any]): token to replace the last token with. If None, the + eos_id (Any | None): token to replace the last token with. If None, the last token will not be replaced. Default is None. truncation_type (str): type of truncation to apply, either "left" or "right". Default is "right". diff --git a/src/forge/data_models/episode.py b/src/forge/data_models/episode.py index 835373d18..6b908ff87 100644 --- a/src/forge/data_models/episode.py +++ b/src/forge/data_models/episode.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Sequence import torch @@ -32,7 +32,7 @@ class Episode: # The log probabilities of the target tokens, for prompt part it's set to 0, # for generation part it's computed from the Generator/Sampler. - log_probs: Optional[torch.Tensor] = None + log_probs: torch.Tensor | None = None # TODO: add more fields as required state: str = "" diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 0c4d15c34..9de5fe60b 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -7,7 +7,7 @@ import asyncio import logging import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Union from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc @@ -120,7 +120,7 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: self.global_logger = global_logger _is_initialized = False diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3c5386af9..b59e4b723 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import pytz from monarch.actor import context, current_rank @@ -60,7 +60,7 @@ class Metric: key: str value: Any reduction: Reduce - timestamp: Optional[float] = None + timestamp: float | None = None def __post_init__(self): if self.timestamp is None: @@ -441,7 +441,7 @@ def __init__(self) -> None: async def init_backends( self, - metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], + metadata_per_primary_backend: Dict[str, Dict[str, Any]] | None, config: Dict[str, Any], ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, @@ -449,7 +449,7 @@ async def init_backends( once globally. Args: - metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary + metadata_per_primary_backend (Dict[str, Dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. """ @@ -592,7 +592,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Dict[str, Any] | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -600,7 +600,7 @@ async def init( Args: role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. - primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for + primary_logger_metadata (Dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. Raises: ValueError if missing metadata for shared local init. @@ -621,7 +621,7 @@ async def log(self, metrics: List[Metric], global_step: int) -> None: async def finish(self) -> None: pass - def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + def get_metadata_for_secondary_ranks(self) -> Dict[str, Any] | None: """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" return None @@ -635,7 +635,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Dict[str, Any] | None = None, ) -> None: self.prefix = ( get_actor_name_with_rank() @@ -688,7 +688,7 @@ def __init__(self, logger_backend_config: Dict[str, Any]) -> None: async def init( self, role: BackendRole, - primary_logger_metadata: Optional[Dict[str, Any]] = None, + primary_logger_metadata: Dict[str, Any] | None = None, ) -> None: if primary_logger_metadata is None: diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index 47577d916..19e8ccb47 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -11,7 +11,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from functools import lru_cache, wraps -from typing import List, Optional, Protocol, Tuple +from typing import List, Protocol, Tuple import torch @@ -111,7 +111,7 @@ def __init__( self._active = False # Timing state - self._timer: Optional[_TimerProtocol] = None + self._timer: _TimerProtocol | None = None # Memory tracking state self._memory_started = False @@ -227,7 +227,7 @@ class _TimerCPU(_TimerProtocol): def __init__(self) -> None: self._durations: List[Tuple[str, float]] = [] - self._chain_start: Optional[float] = None + self._chain_start: float | None = None def start(self) -> None: # Reset state for reuse @@ -259,7 +259,7 @@ def __init__(self, max_workers: int = 2) -> None: Tuple[str, Future[float], int] ] = [] # (name, future, submission_index) self._durations: List[Tuple[str, float]] = [] - self._chain_start: Optional[torch.cuda.Event] = None + self._chain_start: torch.cuda.Event | None = None def start(self) -> None: """Call before any steps. Clear state for reuse; record initial event on current stream.""" diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py index 2850e6828..e47f5dfa3 100644 --- a/src/forge/util/logging.py +++ b/src/forge/util/logging.py @@ -6,17 +6,16 @@ import logging from functools import lru_cache -from typing import Optional from torch import distributed as dist -def get_logger(level: Optional[str] = None) -> logging.Logger: +def get_logger(level: str | None = None) -> logging.Logger: """ Get a logger with a stream handler. Args: - level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels. + level (str | None): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels. Example: >>> logger = get_logger("INFO") diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py index 1141cfbd7..ba8992d20 100644 --- a/src/forge/util/metric_logging.py +++ b/src/forge/util/metric_logging.py @@ -6,7 +6,7 @@ import os import sys import time -from typing import Mapping, Optional, Union +from typing import Mapping, Union from forge.interfaces import MetricLogger from forge.types import Scalar @@ -115,7 +115,7 @@ def __init__( from torch.utils.tensorboard import SummaryWriter self._freq = freq - self._writer: Optional[SummaryWriter] = None + self._writer: SummaryWriter | None = None _, rank = get_world_size_and_rank() # In case organize_logs is `True`, update log_dir to include a subdirectory for the @@ -177,11 +177,11 @@ class WandBLogger(MetricLogger): freq (Union[int, Mapping[str, int]]): If int, all metrics will be logged at this frequency. If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0` - log_dir (Optional[str]): WandB log directory. + log_dir (str | None): WandB log directory. project (str): WandB project name. Default is `torchtune`. - entity (Optional[str]): WandB entity name. If you don't specify an entity, + entity (str | None): WandB entity name. If you don't specify an entity, the run will be sent to your default entity, which is usually your username. - group (Optional[str]): WandB group name for grouping runs together. If you don't + group (str | None): WandB group name for grouping runs together. If you don't specify a group, the run will be logged as an individual experiment. **kwargs: additional arguments to pass to wandb.init @@ -207,8 +207,8 @@ def __init__( freq: Union[int, Mapping[str, int]], project: str, log_dir: str = "metrics_log", - entity: Optional[str] = None, - group: Optional[str] = None, + entity: str | None = None, + group: str | None = None, **kwargs, ): self._freq = freq