Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def setup(self):

self.request_id = 0
self.policy_version = 0
self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}

# TODO: Investigate whether this can be combined with `policy.running`
# Whether this policy is accepting requests.
Expand Down
3 changes: 1 addition & 2 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import os
import socket
import uuid
from typing import Optional

from monarch._src.actor.shape import NDSlice, Shape
from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
Expand Down Expand Up @@ -163,7 +162,7 @@ async def get_proc_mesh(
num_procs: int,
with_gpus: bool = False,
num_hosts: int | None = None,
mesh_name: Optional[str] = None,
mesh_name: str | None = None,
host_mesh: HostMesh | None = None,
env_vars: dict[str, str] | None = None,
addr: str | None = None,
Expand Down
9 changes: 4 additions & 5 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from monarch.actor import ActorError

Expand Down Expand Up @@ -81,7 +80,7 @@ class ServiceRequest:
"""

session_id: Optional[str]
session_id: str | None
function: str
args: tuple
kwargs: dict
Expand All @@ -107,7 +106,7 @@ class Replica:
actor_kwargs: dict

# The Actor that this replica is running
actor: Optional[ForgeActor] = None
actor: ForgeActor | None = None

# Async queue for incoming requests
request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue)
Expand All @@ -127,10 +126,10 @@ class Replica:
return_first_rank_result: bool = False

# Recovery-related state
_recovery_task: Optional[asyncio.Task] = None
_recovery_task: asyncio.Task | None = None

# Run task is the replica's event loop
_run_task: Optional[asyncio.Task] = None
_run_task: asyncio.Task | None = None

# Metrics tracking
metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics)
Expand Down
42 changes: 21 additions & 21 deletions src/forge/data/datasets/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable, Iterator, Optional
from typing import Any, Callable, Iterator

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -37,44 +37,44 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
- Returning an infinite iterator over the dataset

Args:
message_transform (Optional[Transform]): Transforms raw data into a `Message`.
model_transform (Optional[Transform]): Prepares messages for the model,
message_transform (Transform | None): Transforms raw data into a `Message`.
model_transform (Transform | None): Prepares messages for the model,
usually by tokenizing them.
output_transform (Optional[Transform]): Prepares tokenized inputs for the
output_transform (Transform | None): Prepares tokenized inputs for the
recipe, often by manipulating labels (e.g., setting an ignore index).
This transform is recipe-dependent (e.g., SFT, DPO, etc.).
metric_transform (Optional[MetricTransform]): Computes metrics from a
metric_transform (MetricTransform | None): Computes metrics from a
sample (e.g., token count). If ``None``, a default transform is used.
To disable standard metric tracking, set this to ``lambda x: x``.
shuffle_buffer_size (Optional[int]): Size of the shuffle buffer.
shuffle_buffer_size (int | None): Size of the shuffle buffer.
If ``None`` or 0, no shuffling is performed.
weight (Optional[float]): Weight for this dataset. Defaults to 1.0.
weight (float | None): Weight for this dataset. Defaults to 1.0.
seed (int): Seed for shuffling.
num_shards_per_rank (int): The target number of shards per worker (GPU).
The actual number of shards will be a multiple of
``world_size * dataloader_workers``.
dataset_name (Optional[str]): Name of the dataset. If ``None``, a name is
dataset_name (str | None): Name of the dataset. If ``None``, a name is
generated from the ``path``, ``source``, and ``split``.
filter_fn (Optional[Callable]): A function to filter the dataset.
filter_kwargs (Optional[dict[str, Any]]): Keyword arguments for ``filter_fn``.
filter_fn (Callable | None): A function to filter the dataset.
filter_kwargs (dict[str, Any] | None): Keyword arguments for ``filter_fn``.
**load_dataset_kwargs: Keyword arguments for the
:func:`~datasets.load_dataset` function.
"""

def __init__(
self,
*,
message_transform: Optional[Transform] = None,
model_transform: Optional[Transform] = None,
output_transform: Optional[Transform] = None,
metric_transform: Optional[MetricTransform] = None,
shuffle_buffer_size: Optional[int] = 1000,
weight: Optional[float] = 1.0,
message_transform: Transform | None = None,
model_transform: Transform | None = None,
output_transform: Transform | None = None,
metric_transform: MetricTransform | None = None,
shuffle_buffer_size: int | None = 1000,
weight: float | None = 1.0,
seed: int = 42,
num_shards_per_rank: int = 64,
dataset_name: Optional[str] = None,
filter_fn: Optional[Callable] = None,
filter_kwargs: Optional[dict[str, Any]] = None,
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
**load_dataset_kwargs,
):
# Store configuration
Expand Down Expand Up @@ -135,8 +135,8 @@ def _setup_hf_dataset(
self,
load_dataset_kwargs: dict[str, Any],
num_shards_per_rank: int,
filter_fn: Optional[Callable] = None,
filter_kwargs: Optional[dict[str, Any]] = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
):
"""
One-time setup of HuggingFace dataset that handles distributed sharding,
Expand Down
14 changes: 7 additions & 7 deletions src/forge/data/datasets/packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
from abc import ABC, abstractmethod
from collections import deque
from typing import Any, Generic, Iterable, Iterator, Optional, TypeVar
from typing import Any, Generic, Iterable, Iterator, TypeVar

import torch
from torch.nn.attention.flex_attention import (
Expand Down Expand Up @@ -329,13 +329,13 @@ def _reset_packer_state(self) -> None:
self._buffer.clear()

# current_pack: the current pack being built
self._current_pack: Optional[dict[str, list]] = None
self._current_pack: dict[str, list] | None = None

# current_pack_size: the number of tokens in the current pack
self._current_pack_size: int = 0

# iterator: the iterator over the dataset
self._iterator: Optional[Iterator[SampleType]] = None
self._iterator: Iterator[SampleType] | None = None

# current_doc_id_in_pack: the document ID to use for the next sample
self._current_doc_id_in_pack: int = 0
Expand Down Expand Up @@ -367,15 +367,15 @@ def _fill_buffer(self, iterator: Iterator[SampleType]) -> None:
except StopIteration:
self._exhausted = True

def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]:
def _find_next_fitting_sample(self, remaining_size: int) -> int | None:
"""
Find the first sample in the buffer that fits in the remaining space.

Args:
remaining_size (int): The remaining space in the current pack.

Returns:
Optional[int]: The index of the sample in the buffer, or None if no sample fits.
int | None: The index of the sample in the buffer, or None if no sample fits.

Example:
self._buffer = deque([(sample1, 200), (sample2, 100), (sample3, 48), (sample4, 200)])
Expand All @@ -397,15 +397,15 @@ def _find_next_fitting_sample(self, remaining_size: int) -> Optional[int]:
return i
return None

def _build_one_pack(self, iterator: Iterator[SampleType]) -> Optional[SampleDict]:
def _build_one_pack(self, iterator: Iterator[SampleType]) -> SampleDict | None:
"""
Builds a pack of samples from the buffer.

Args:
iterator (Iterator[SampleType]): The iterator over the dataset.

Returns:
Optional[SampleDict]: The pack of samples, or None if the dataset is exhausted.
SampleDict | None: The pack of samples, or None if the dataset is exhausted.
"""
# Start a new pack if necessary
if self._current_pack is None:
Expand Down
22 changes: 11 additions & 11 deletions src/forge/data/datasets/sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Optional
from typing import Any, Callable

import torch

Expand All @@ -26,7 +26,7 @@ class AlpacaToMessages(Transform):
due to this custom logic.

Args:
column_map (Optional[dict[str, str]]): a mapping to change the expected "instruction", "input",
column_map (dict[str, str] | None): a mapping to change the expected "instruction", "input",
and "output" column names to the actual column names in the dataset. Default is None,
keeping the default column names.
masking_strategy (str): masking strategy to use for model training.
Expand All @@ -45,7 +45,7 @@ class AlpacaToMessages(Transform):

def __init__(
self,
column_map: Optional[dict[str, str]] = None,
column_map: dict[str, str] | None = None,
masking_strategy: str = "train_on_all",
):
self.masking_strategy = masking_strategy
Expand Down Expand Up @@ -158,12 +158,12 @@ def sft_iterable_dataset(
*,
weight: int = 1,
message_transform: Transform,
shuffle_buffer_size: Optional[int] = 1000,
shuffle_buffer_size: int | None = 1000,
seed: int = 42,
num_shards_per_rank: int = 64,
dataset_name: Optional[str] = None,
filter_fn: Optional[Callable] = None,
filter_kwargs: Optional[dict[str, Any]] = None,
dataset_name: str | None = None,
filter_fn: Callable | None = None,
filter_kwargs: dict[str, Any] | None = None,
**load_dataset_kwargs: dict[str, Any],
) -> HfIterableDataset:
"""
Expand All @@ -173,12 +173,12 @@ def sft_iterable_dataset(
model_transform (Transform): Usually the tokenizer
weight (int): Weight of the dataset. Used for sampling when interleaving datasets.
message_transform (Transform): Transform to convert raw data to messages
shuffle_buffer_size (Optional[int]): Buffer size for shuffling
shuffle_buffer_size (int | None): Buffer size for shuffling
seed (int): Random seed for shuffling
num_shards_per_rank (int): Target shards per worker
dataset_name (Optional[str]): Name for metrics namespacing
filter_fn (Optional[Callable]): Filter function
filter_kwargs (Optional[dict[str, Any]]): Filter function kwargs
dataset_name (str | None): Name for metrics namespacing
filter_fn (Callable | None): Filter function
filter_kwargs (dict[str, Any] | None): Filter function kwargs
**load_dataset_kwargs (dict[str, Any]): Args passed to load_dataset

Returns:
Expand Down
20 changes: 10 additions & 10 deletions src/forge/data/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import json
from typing import Any, Optional
from typing import Any

import jinja2
from jinja2 import StrictUndefined
Expand All @@ -28,8 +28,8 @@ class HuggingFaceBaseTokenizer(BaseTokenizer):

Args:
tokenizer_json_path (str): Path to tokenizer.json file
tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None
generation_config_path (Optional[str]): Path to generation_config.json file.
tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
generation_config_path (str | None): Path to generation_config.json file.
Default: None

Raises:
Expand All @@ -40,8 +40,8 @@ def __init__(
self,
tokenizer_json_path: str,
*,
tokenizer_config_json_path: Optional[str] = None,
generation_config_path: Optional[str] = None,
tokenizer_config_json_path: str | None = None,
generation_config_path: str | None = None,
):
self.tokenizer = Tokenizer.from_file(tokenizer_json_path)
if not (tokenizer_config_json_path or generation_config_path):
Expand Down Expand Up @@ -209,8 +209,8 @@ class HuggingFaceModelTokenizer(ModelTokenizer):

Args:
tokenizer_json_path (str): Path to tokenizer.json file
tokenizer_config_json_path (Optional[str]): Path to tokenizer_config.json file. Default: None
generation_config_path (Optional[str]): Path to generation_config.json file.
tokenizer_config_json_path (str | None): Path to tokenizer_config.json file. Default: None
generation_config_path (str | None): Path to generation_config.json file.
Default: None
truncation_type (str): type of truncation to apply, either "left" or "right".
Default is "right".
Expand All @@ -220,8 +220,8 @@ def __init__(
self,
tokenizer_json_path: str,
*,
tokenizer_config_json_path: Optional[str] = None,
generation_config_path: Optional[str] = None,
tokenizer_config_json_path: str | None = None,
generation_config_path: str | None = None,
truncation_type: str = "right",
):
self.base_tokenizer = HuggingFaceBaseTokenizer(
Expand Down Expand Up @@ -274,7 +274,7 @@ def tokenize_messages(
self,
messages: list[Message],
add_eos: bool = True,
max_seq_len: Optional[int] = None,
max_seq_len: int | None = None,
) -> tuple[list[int], list[bool]]:
tokenized_messages = []
mask = []
Expand Down
6 changes: 3 additions & 3 deletions src/forge/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Union

import torch

Expand Down Expand Up @@ -118,7 +118,7 @@ def __repr__(self) -> str:
def truncate(
tokens: list[Any],
max_seq_len: int,
eos_id: Optional[Any] = None,
eos_id: Any | None = None,
truncation_type: str = "right",
) -> list[Any]:
"""
Expand All @@ -128,7 +128,7 @@ def truncate(
Args:
tokens (list[Any]): list of tokens to truncate
max_seq_len (int): maximum length of the list
eos_id (Optional[Any]): token to replace the last token with. If None, the
eos_id (Any | None): token to replace the last token with. If None, the
last token will not be replaced. Default is None.
truncation_type (str): type of truncation to apply, either "left" or "right".
Default is "right".
Expand Down
4 changes: 2 additions & 2 deletions src/forge/data_models/episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence
from typing import Sequence

import torch

Expand All @@ -32,7 +32,7 @@ class Episode:

# The log probabilities of the target tokens, for prompt part it's set to 0,
# for generation part it's computed from the Generator/Sampler.
log_probs: Optional[torch.Tensor] = None
log_probs: torch.Tensor | None = None

# TODO: add more fields as required
state: str = ""
Expand Down
4 changes: 2 additions & 2 deletions src/forge/observability/metric_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import asyncio
import logging
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, Union

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc

Expand Down Expand Up @@ -120,7 +120,7 @@ class LocalFetcherActor(Actor):
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
"""

def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None:
def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
self.global_logger = global_logger
_is_initialized = False

Expand Down
Loading
Loading