diff --git a/.github/workflows/unit_test.yaml b/.github/workflows/unit_test.yaml new file mode 100644 index 000000000..88b08ba76 --- /dev/null +++ b/.github/workflows/unit_test.yaml @@ -0,0 +1,36 @@ +name: Unit Test + +on: + pull_request: + + +jobs: + unit_tests: + runs-on: ubuntu-latest + timeout-minutes: 15 + strategy: + matrix: + python-version: ['3.10', '3.11', '3.12'] + steps: + - name: Check out repo + uses: actions/checkout@v4 + - name: Setup conda env + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + miniconda-version: "latest" + activate-environment: test + python-version: ${{ matrix.python-version }} + - name: Update pip + run: python -m pip install --upgrade pip + - name: Install dependencies + run: | + python -m pip install -e ".[dev]" + # Will have to pin until monarch wheel is reasonable + pip install torch==2.9.0.dev20250815+cpu --index-url https://download.pytorch.org/whl/nightly/cpu + python -m pip install --no-build-isolation --verbose assets/wheels/monarch_no_torch-0.1.0.dev20250815-py3-none-any.whl + - name: Run unit tests with coverage + # TODO add all tests + run: pytest tests/unit_tests --cov=. --cov-report=xml --durations=20 -vv + - name: Upload Coverage to Codecov + uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index 7a405ecf0..3198097ef 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,6 @@ lib64/ parts/ sdist/ var/ -wheels/ share/python-wheels/ *.egg-info/ .installed.cfg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 854ee0e97..1a9bad2ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,8 +12,6 @@ repos: - id: check-merge-conflict - id: no-commit-to-branch args: ['--branch=main'] - - id: check-added-large-files - args: ['--maxkb=1000'] - id: end-of-file-fixer exclude: '^(.*\.svg)$' diff --git a/apps/toy_rl/main.py b/apps/toy_rl/main.py index f5cbd203d..27461d550 100644 --- a/apps/toy_rl/main.py +++ b/apps/toy_rl/main.py @@ -16,7 +16,6 @@ import torch from forge.actors.collector import Collector -from forge.controller.stack import stack from forge.data.replay_buffer import ReplayBuffer from forge.interfaces import Environment, Policy from forge.types import Action, Observation, State @@ -150,7 +149,8 @@ async def main(): # This policy just generates something between -2. and 2. policy = await policy_procs.spawn("policy", ToyPolicy, action_range=(-2.0, 2.0)) - browser_collectors = await browser_procs.spawn( + # TODO - replace multiple collectors with a service. + collectors = await browser_procs.spawn( "browser", Collector, max_collector_steps=5, @@ -162,29 +162,6 @@ async def main(): # to do it this way. 
environment_creator=partial(ToyEnvironment, name="browser", max_steps=5), ) - deep_research_collectors = await deep_research_procs.spawn( - "deep_research", - Collector, - max_collector_steps=5, - policy=policy, - replay_buffer=replay_buffer, - environment_creator=partial(ToyEnvironment, name="deep_research", max_steps=5), - ) - coding_collectors = await coder_procs.spawn( - "coding", - Collector, - max_collector_steps=5, - policy=policy, - replay_buffer=replay_buffer, - environment_creator=partial(ToyEnvironment, name="coding", max_steps=5), - ) - - collectors = stack( - browser_collectors, - deep_research_collectors, - coding_collectors, - interface=Collector, - ) # Create two async tasks async def episode_collector_task(): @@ -193,16 +170,7 @@ async def episode_collector_task(): while True: try: print(f"🎮 Running episode {episode_count + 1}...") - - # call() is essentially our "map" - every collector runs their own - # episode loop. - # What's pretty elegant here is if we wanted to control off policiness, we could - # easily counter on steps and call policy.update_weights.call() at our desired - # frequency. - results = collectors.run_episode.call() - - # Temporary hack due to Monarch changes - ideally you could just await results - results = [await r for r in results] + results = await collectors.run_episode.call() num_trajectories = sum([len(r._values) for r in results]) episode_count += 1 print( diff --git a/assets/version.txt b/assets/version.txt new file mode 100644 index 000000000..6e8bf73aa --- /dev/null +++ b/assets/version.txt @@ -0,0 +1 @@ +0.1.0 diff --git a/assets/wheels/monarch_no_torch-0.1.0.dev20250815-py3-none-any.whl b/assets/wheels/monarch_no_torch-0.1.0.dev20250815-py3-none-any.whl new file mode 100644 index 000000000..52287b4cc Binary files /dev/null and b/assets/wheels/monarch_no_torch-0.1.0.dev20250815-py3-none-any.whl differ diff --git a/pyproject.toml b/pyproject.toml index 0d8f8eeca..4062a9a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ # PyTorch "torchdata>=0.8.0", "torchtitan", + "torchao", # vLLM # TODO: pin specific vllm version #"vllm==0.10.0", @@ -36,6 +37,7 @@ dev = [ "pytest-cov", "pytest-timeout", "tensorboard", + "expecttest", "tomli>=1.1.0", "anyio", "pytest-asyncio", diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py index be6b0af75..84937b36b 100644 --- a/src/forge/actors/__init__.py +++ b/src/forge/actors/__init__.py @@ -5,6 +5,5 @@ # LICENSE file in the root directory of this source tree. from .collector import Collector -from .policy import Policy, PolicyRouter -__all__ = ["Collector", "Policy", "PolicyRouter"] +__all__ = ["Collector"] diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index d27f0a6ab..e3215720a 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -8,7 +8,6 @@ from .recoverable_mesh import RecoverableProcMesh from .service import AutoscalingConfig, Service, ServiceConfig from .spawn import spawn_service -from .stack import stack __all__ = [ "AutoscalingConfig", @@ -19,5 +18,4 @@ "get_proc_mesh", "ForgeActor", "RecoverableProcMesh", - "stack", ] diff --git a/src/forge/controller/stack.py b/src/forge/controller/stack.py deleted file mode 100644 index 4d2dff2dc..000000000 --- a/src/forge/controller/stack.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
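What follows is the body of the deleted `stack.py`. As orientation before the full listing, here is a sketch of the API being removed, using only calls that appear in the deleted source and its tests (`proc1`, `proc2`, and `Counter` are the fixtures from the deleted `test_controller.py` further down):

```python
# Sketch of the removed API, reconstructed from the deleted code below.
counter1 = proc1.spawn("counter1", Counter, 0).get()
counter2 = proc2.spawn("counter2", Counter, 0).get()

stacked = stack(counter1, counter2)    # interface inferred via _common_ancestor
futures = stacked.incr.call()          # fan-out: one future per underlying mesh
results = [f.get() for f in futures]   # post-Monarch-change: a list, not one future
stacked.incr.broadcast()               # fire-and-forget to every stacked actor
value = stacked.value.choose().get()   # load-balanced send to one random actor
```

The PR replaces this fan-out with a single awaited endpoint call on one mesh (see the `apps/toy_rl/main.py` hunk above), with a service noted in the TODO as the longer-term replacement.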
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import inspect -import random -from typing import ( - Any, - Coroutine, - Generator, - Generic, - Optional, - overload, - ParamSpec, - Tuple, - Type, - TypeVar, -) - -from monarch._rust_bindings.monarch_hyperactor.shape import Shape -from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, Endpoint, ValueMesh - -from monarch._src.actor.future import Future -from monarch._src.actor.shape import MeshTrait, NDSlice - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -P = ParamSpec("P") -R = TypeVar("R") -A = TypeVar("A") - - -class StackedEndpoint(Generic[P, R]): - """ - A class that represents a collection of endpoints stacked together. - - This class allows for operations to be performed across multiple endpoints - as if they were a single entity. - - This provides the same interface as the Endpoint class. - - """ - - def __init__(self, endpoints: list[Endpoint]) -> None: - self.endpoints = endpoints - - def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - """Load balanced sends a message to one chosen actor from all stacked actors.""" - endpoint = random.choice(self.endpoints) - return endpoint.choose(*args, **kwargs) - - def call(self, *args: P.args, **kwargs: P.kwargs) -> "list[Future[ValueMesh[R]]]": - """Sends a message to all actors in all stacked endpoints and collects results. - - Currently this returns a list of futures rather than a single future, due to - changes in Monarch. - - """ - return [endpoint.call(*args, **kwargs) for endpoint in self.endpoints] - - def _stream( - self, *args: P.args, **kwargs: P.kwargs - ) -> "Generator[Coroutine[Any, Any, R], None, None]": - """ - Broadcasts to all actors in all stacked endpoints and yields their responses as coroutines. - - This enables processing results from multiple stacked endpoints incrementally as - they become available. Returns a generator of coroutines that can be awaited. 
- """ - for endpoint in self.endpoints: - # Get the coroutines from each endpoint's _stream method - for coro in endpoint._stream(*args, **kwargs): - yield coro - - def stream( - self, *args: P.args, **kwargs: P.kwargs - ) -> Generator[Future[R], None, None]: - """Broadcasts to all actors in all stacked endpoints and yields responses as a stream.""" - for endpoint in self.endpoints: - # endpoint.stream() returns a Generator[Future[R], None, None] - for future in endpoint.stream(*args, **kwargs): - yield future - - def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: - """Fire-and-forget broadcast to all actors in all stacked endpoints.""" - for endpoint in self.endpoints: - endpoint.broadcast(*args, **kwargs) - - -class StackedActorMeshRef(MeshTrait): - def __init__(self, *actors: ActorMeshRef, interface=None) -> None: - self._actors = actors - self._interface = interface - - # Create endpoints by looking at the interface class for endpoint methods - if interface is not None and actors: - # Look for methods decorated with @endpoint in the interface class - for attr_name in dir(interface): - if not attr_name.startswith("_"): # Skip private methods - # Check if this method exists as an endpoint on the first actor - first_actor = actors[0] - if hasattr(first_actor, attr_name): - first_endpoint = getattr(first_actor, attr_name) - if isinstance(first_endpoint, Endpoint): - # Get the corresponding endpoint from each mesh - endpoints = [] - for mesh in self._actors: - if hasattr(mesh, attr_name): - endpoint = getattr(mesh, attr_name) - if isinstance(endpoint, Endpoint): - endpoints.append(endpoint) - - # Create a stacked endpoint with all the collected endpoints - if endpoints and len(endpoints) == len(self._actors): - setattr(self, attr_name, StackedEndpoint(endpoints)) - - def __getattr__(self, name: str) -> StackedEndpoint: - """ - Fallback for accessing dynamically created endpoint attributes. - This helps the type checker understand that any attribute access - on a StackedActorMeshRef should return a StackedEndpoint. - """ - # This should only be called if the attribute doesn't exist - # which means it wasn't created in __init__, so we raise AttributeError - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - @property - def _ndslice(self) -> NDSlice: - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - @property - def _labels(self) -> Tuple[str, ...]: - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - def _new_with_shape(self, shape: Shape) -> "StackedActorMeshRef": - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - -def _common_ancestor(*actors: ActorMeshRef) -> Optional[Type]: - """Finds the common ancestor class of a list of actor mesh references. - - This determines the most specific common base class shared by all - provided actors. 
- - Args: - *actors: Variable number of ActorMeshRef instances to analyze - - Returns: - Optional[Type]: The most specific common ancestor class, or None if no - actors were provided or no common ancestor exists - - Example: - ```python - # Find common ancestor of two counter actors - counter_a = proc.spawn("counter_a", CounterA, 0).get() - counter_b = proc.spawn("counter_b", CounterB, 0).get() - common_class = _common_ancestor(counter_a, counter_b) # Returns Counter - ``` - """ - if not actors: - return None - base_classes = [obj._class for obj in actors] - all_mros = [inspect.getmro(cls) for cls in base_classes] - common_bases = set(all_mros[0]).intersection(*all_mros[1:]) - if common_bases: - return min( - common_bases, key=lambda cls: min(mro.index(cls) for mro in all_mros) - ) - return None - - -@overload -def stack(*actors: Any, interface: Type[T]) -> StackedActorMeshRef: - pass - - -@overload -def stack(*actors: Any) -> StackedActorMeshRef: - pass - - -def stack(*actors: Any, interface: Optional[Type] = None) -> StackedActorMeshRef: - """Stacks multiple actor mesh references into a single unified interface. - - This allows you to combine multiple actors that share a common interface - into a single object that can be used to interact with all of them simultaneously. - When methods are called on the stacked actor, they are distributed to all - underlying actors according to the endpoint's behavior (choose, call, stream, etc). - - Args: - *actors: Variable number of ActorMeshRef instances to stack together - interface: Optional class that defines the interface to expose. If not provided, - the common ancestor class of all actors will be used. - - Returns: - StackedActorMeshRef: A reference that provides access to all stacked actors - through a unified interface. - - Raises: - TypeError: If any of the provided objects is not an ActorMeshRef, or if - no common ancestor can be found and no interface is provided. - - Example: - ```python - # Stack two counter actors together - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - stacked = stack(counter1, counter2) - - # Call methods on all actors at once - stacked.incr.broadcast() # Increments both counters - ``` - - """ - for actor in actors: - if not isinstance(actor, ActorMeshRef): - raise TypeError( - "stack be provided with Monarch Actors, got {}".format(type(actor)) - ) - if interface is None: - interface = _common_ancestor(*actors) - - if interface is None or interface == Actor: - raise TypeError( - "No common ancestor found for the given actors. Please provide an interface explicitly." - ) - return StackedActorMeshRef(*actors, interface=interface) diff --git a/src/forge/data/datasets/__init__.py b/src/forge/data/datasets/__init__.py index 2b18c5a64..8d62e221d 100644 --- a/src/forge/data/datasets/__init__.py +++ b/src/forge/data/datasets/__init__.py @@ -4,16 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
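The next hunk promotes `InterleavedDataset` to a package-level export and fixes the missing trailing newline. For reference, the resulting public import surface, mirroring the updated `__all__`:

```python
# Everything exported from forge.data.datasets once this hunk is applied.
from forge.data.datasets import (
    DatasetInfo,
    HfIterableDataset,
    InfiniteTuneIterableDataset,
    InterleavedDataset,
    PackedDataset,
    SFTOutputTransform,
    sft_iterable_dataset,
)
```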
-from .dataset import DatasetInfo, InfiniteTuneIterableDataset +from .dataset import DatasetInfo, InfiniteTuneIterableDataset, InterleavedDataset from .hf_dataset import HfIterableDataset from .packed import PackedDataset -from .sft_dataset import SFTOutputTransform, sft_iterable_dataset +from .sft_dataset import sft_iterable_dataset, SFTOutputTransform __all__ = [ "DatasetInfo", "HfIterableDataset", + "InterleavedDataset", "InfiniteTuneIterableDataset", "PackedDataset", "SFTOutputTransform", "sft_iterable_dataset", -] \ No newline at end of file +] diff --git a/tests/integration_tests.py b/tests/integration_tests.py deleted file mode 100755 index 3ccbc1890..000000000 --- a/tests/integration_tests.py +++ /dev/null @@ -1,603 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -import logging -import os -import subprocess -from collections import defaultdict -from dataclasses import dataclass -from typing import Sequence - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -try: - import tomllib -except ModuleNotFoundError: - import tomli as tomllib - - -@dataclass -class OverrideDefinitions: - """ - This class is used to define the override definitions for the integration tests. - """ - - override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) - test_descr: str = "default" - test_name: str = "default" - ngpu: int = 4 - - def __repr__(self): - return self.test_descr - - -def build_test_list(): - """ - key is the config file name and value is a list of OverrideDefinitions - that is used to generate variations of integration tests based on the - same root config file. 
- """ - integration_tests_flavors = defaultdict(list) - integration_tests_flavors["debug_model.toml"] = [ - OverrideDefinitions( - [ - [ - "--profiling.enable_profiling", - "--metrics.enable_tensorboard", - ], - ], - "default", - "default", - ), - OverrideDefinitions( - [ - [ - "--training.compile", - ], - ], - "1D compile", - "1d_compile", - ), - OverrideDefinitions( - [ - [ - "--training.compile", - "--activation_checkpoint.mode selective", - "--activation_checkpoint.selective_ac_option op", - ], - ], - "1D compile with selective op AC", - "1d_compile_sac_op", - ), - OverrideDefinitions( - [ - [ - "--parallelism.tensor_parallel_degree 2", - ], - ], - "2D eager", - "2d_eager", - ), - OverrideDefinitions( - [ - [ - "--training.compile", - "--parallelism.tensor_parallel_degree 2", - ], - ], - "2D compile", - "2d_compile", - ), - # TODO: re-enable this test once the async TP issue is fixed - # OverrideDefinitions( - # [ - # [ - # "--training.compile", - # "--parallelism.tensor_parallel_degree 2", - # "--parallelism.enable_async_tensor_parallel", - # ], - # ], - # "2D async TP compile", - # "2d_asynctp_compile", - # ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - ], - [ - "--checkpoint.enable_checkpoint", - "--training.steps 20", - ], - ], - "Checkpoint Integration Test - Save Load Full Checkpoint", - "full_checkpoint", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.last_save_model_weights_only", - ], - ], - "Checkpoint Integration Test - Save Model Weights Only fp32", - "last_save_model_weights_only_fp32", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.last_save_model_weights_only", - "--checkpoint.export_dtype bfloat16", - ], - ], - "Checkpoint Integration Test - Save Model Weights Only bf16", - "last_save_model_weights_only_bf16", - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 4", - "--parallelism.pipeline_parallel_schedule InterleavedZeroBubble", - ], - ], - "PP looped zero bubble test", - "pp_looped_zero_bubble", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule ZBVZeroBubble", - ], - ], - "PP zero bubble test (v shaped)", - "pp_zbv", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule 1F1B", - "--parallelism.data_parallel_shard_degree 1", - ], - ], - "PP 1D test 1F1B", - "pp_1f1b", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule GPipe", - "--parallelism.data_parallel_shard_degree 1", - ], - ], - "PP 1D test GPipe", - "pp_gpipe", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule 1F1B", - "--parallelism.data_parallel_shard_degree 2", - ], - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule 1F1B", - "--parallelism.pipeline_parallel_layers_per_stage 4", - "--parallelism.data_parallel_shard_degree 2", - ], - ], - "PP+DP 1F1B 2D test", - "pp_dp_1f1b", - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule GPipe", - "--parallelism.data_parallel_shard_degree 2", - ], - ], - "PP+DP GPipe 2D test", - "pp_dp_gpipe", - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - 
"--parallelism.tensor_parallel_degree 2", - ], - ], - "PP+TP 2D test", - "pp_tp", - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - ], - [ - "--training.steps 20", - "--checkpoint.enable_checkpoint", - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - ], - ], - "PP+DP+TP 3D test with save/load resume ckpt", - "pp_dp_tp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.data_parallel_shard_degree 2", - "--parallelism.tensor_parallel_degree 2", - "--training.compile", - ], - ], - "PP+DP+TP 3D test with torch.compile", - "3d_compile", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 4", - "--parallelism.pipeline_parallel_schedule Interleaved1F1B", - ], - [ - "--parallelism.pipeline_parallel_degree 4", - "--parallelism.pipeline_parallel_schedule Interleaved1F1B", - "--parallelism.pipeline_parallel_layers_per_stage 1", - ], - ], - "PP looped 1F1B test", - "pp_looped_1f1b", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.pipeline_parallel_degree 2", - "--parallelism.pipeline_parallel_schedule PipelineScheduleMulti", - "--parallelism.pipeline_parallel_schedule_csv ./tests/assets/custom_schedule.csv", - ], - ], - "PP with custom pipeline schedule loaded from CSV file", - "pp_custom_csv", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--optimizer.name AdamW --optimizer.implementation foreach", - ] - ], - "Foreach Optimizer Test", - "optimizer_foreach", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=1", - "--parallelism.data_parallel_replicate_degree=4", - ] - ], - "DDP", - "ddp", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.data_parallel_replicate_degree=2", - ] - ], - "HSDP", - "hsdp", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=4", - "--activation_checkpoint.mode='full'", - "--model.flavor=debugmodel_flex_attn", - ] - ], - "FSDP+FLEX_ATTN", - "fsdp+flex_attn", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.context_parallel_degree=4", - "--parallelism.context_parallel_rotate_method='allgather'", - ] - ], - "CP (allgather)", - "cp_allgather", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.context_parallel_degree=4", - "--parallelism.context_parallel_rotate_method='alltoall'", - ] - ], - "CP (alltoall)", - "cp_alltoall", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.data_parallel_replicate_degree=2", - "--parallelism.tensor_parallel_degree=2", - ] - ], - "HSDP+TP", - "hsdp+tp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.context_parallel_degree=2", - ] - ], - "FSDP+CP", - "fsdp+cp", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=1", - "--parallelism.data_parallel_replicate_degree=2", - "--parallelism.context_parallel_degree=2", - ] - ], - "HSDP+CP (without dp_shard)", - "hsdp+cp_without_dp_shard", - ngpu=4, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.data_parallel_replicate_degree=2", - "--parallelism.context_parallel_degree=2", 
- ] - ], - "HSDP+CP (with dp_shard)", - "hsdp+cp_with_dp_shard", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.tensor_parallel_degree=2", - "--parallelism.context_parallel_degree=2", - ] - ], - "FSDP+TP+CP", - "fsdp+tp+cp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--parallelism.tensor_parallel_degree=2", - "--parallelism.context_parallel_degree=2", - "--training.enable_cpu_offload", - "--optimizer.early_step_in_backward", - ], - [ - "--parallelism.tensor_parallel_degree=2", - "--parallelism.context_parallel_degree=2", - "--parallelism.data_parallel_replicate_degree=2", - "--training.enable_cpu_offload", - "--optimizer.early_step_in_backward", - ], - ], - "Enable CPU Offload, Optimizer in backward with TP, DP, CP", - "cpu_offload+opt_in_bwd+TP+DP+CP", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--memory_estimation.enabled", - ] - ], - "FSDP2 Memory Tracking and Estimation", - "fsdp2_memory_estimation", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - ], - [ - # placeholder for the generation script's generate step - ], - ], - "Generation script test", - "test_generate", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--parallelism.fsdp_reshard_after_forward always", - ], - ], - "Test always resharding after forward pass", - "fsdp_reshard_always", - ngpu=2, - ), - OverrideDefinitions( - [ - [ - "--checkpoint.enable_checkpoint", - "--training.steps 10", - ], - # Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be - # excluded during loading to avoid errors caused by mismatched dp_degree. - [ - "--checkpoint.enable_checkpoint", - "--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer", - "--parallelism.tensor_parallel_degree 2", - "--training.steps 20", - ], - ], - "Optional checkpoint", - "optional_checkpoint", - ), - OverrideDefinitions( - [ - [ - "--model.converters float8", - "--float8.enable_fsdp_float8_all_gather", - "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", - "--float8.emulate", - ], - ], - "Float8 emulation test", - "float8_emulation", - ), - OverrideDefinitions( - [ - [ - # Local batch size = 8, and `ngpu=2`, so default - # global batch size = 8 * 2 = 16. - # To achieve 2 gradient accumulation steps, multiply - # default global batch size by 2. 16 * 2 = 32. - "--training.local_batch_size 8", - "--training.global_batch_size 32", - ], - ], - "Gradient accumulation", - "gradient_accumulation", - ngpu=2, - ), - ] - return integration_tests_flavors - - -def _run_cmd(cmd): - return subprocess.run([cmd], text=True, shell=True) - - -def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): - # run_test supports sequence of tests. 
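As a worked example of what this deleted harness produced — reconstructed from the function body that continues below, with illustrative paths (the config file and output directory are placeholders, not values from this PR):

```python
# How one flavor expanded into a shell command (real logic in run_test below).
flavor = OverrideDefinitions([["--training.compile"]], "1D compile", "1d_compile")
output_dir, full_path = "out", "./train_configs/debug_model.toml"  # illustrative
all_ranks = ",".join(map(str, range(flavor.ngpu)))  # "0,1,2,3" for ngpu=4
cmd = (
    f'TORCH_TRACE="{output_dir}/{flavor.test_name}/compile_trace" '
    f"CONFIG_FILE={full_path} NGPU={flavor.ngpu} LOG_RANK={all_ranks} "
    f"./run_train.sh --job.dump_folder {output_dir}/{flavor.test_name} "
    "--training.compile"
)
```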
- test_name = test_flavor.test_name - dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}" - all_ranks = ",".join(map(str, range(test_flavor.ngpu))) - - for idx, override_arg in enumerate(test_flavor.override_args): - cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh" - # dump compile trace for debugging purpose - cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd - if test_name == "fsdp2_memory_estimation": - cmd = ( - f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " - "./scripts/estimate/run_memory_estimation.sh" - ) - cmd += " " + dump_folder_arg - if override_arg: - cmd += " " + " ".join(override_arg) - logger.info( - f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" - ) - - # save checkpoint (idx == 0) and load it for generation (idx == 1) - if test_name == "test_generate" and idx == 1: - cmd = ( - f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " - f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 " - "PROMPT='What is the meaning of life?' " - f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json" - ) - - result = _run_cmd(cmd) - logger.info(result.stdout) - if result.returncode != 0: - raise Exception( - f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}" - ) - - -def run_tests(args): - integration_tests_flavors = build_test_list() - for config_file in os.listdir(args.config_dir): - if config_file.endswith(".toml"): - full_path = os.path.join(args.config_dir, config_file) - with open(full_path, "rb") as f: - config = tomllib.load(f) - is_integration_test = config["job"].get( - "use_for_integration_test", False - ) - if is_integration_test: - for test_flavor in integration_tests_flavors[config_file]: - if args.test == "all" or test_flavor.test_name == args.test: - if args.ngpu < test_flavor.ngpu: - logger.info( - f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus," - f" because --ngpu arg is {args.ngpu}" - ) - else: - run_test(test_flavor, full_path, args.output_dir) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("output_dir") - parser.add_argument( - "--config_dir", default="./torchtitan/models/llama3/train_configs" - ) - parser.add_argument( - "--test", - default="all", - help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", - ) - parser.add_argument("--ngpu", default=8, type=int) - args = parser.parse_args() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - if os.listdir(args.output_dir): - raise RuntimeError("Please provide an empty output directory.") - run_tests(args) - - -if __name__ == "__main__": - main() diff --git a/tests/integration_tests_h100.py b/tests/integration_tests_h100.py deleted file mode 100755 index da9539957..000000000 --- a/tests/integration_tests_h100.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import argparse -import logging -import os -import subprocess -from collections import defaultdict -from dataclasses import dataclass -from typing import Sequence - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -try: - import tomllib -except ModuleNotFoundError: - import tomli as tomllib - - -@dataclass -class OverrideDefinitions: - """ - This class is used to define the override definitions for the integration tests. - """ - - override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) - test_descr: str = "default" - test_name: str = "default" - ngpu: int = 4 - - def __repr__(self): - return self.test_descr - - -def build_test_list(): - """ - key is the config file name and value is a list of OverrideDefinitions - that is used to generate variations of integration tests based on the - same root config file. - """ - integration_tests_flavors = defaultdict(list) - integration_tests_flavors["debug_model.toml"] = [ - OverrideDefinitions( - [ - [ - "--training.compile", - "--parallelism.tensor_parallel_degree 2", - "--parallelism.enable_async_tensor_parallel", - ], - ], - "2D async TP compile", - "2d_asynctp_compile", - ), - OverrideDefinitions( - [ - [ - "--model.converters float8", - "--float8.enable_fsdp_float8_all_gather", - "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", - ], - ], - "Float8 test", - "float8", - ), - OverrideDefinitions( - [ - [ - "--training.compile", - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.tensor_parallel_degree=2", - "--parallelism.pipeline_parallel_degree=2", - "--parallelism.enable_async_tensor_parallel", - "--model.converters float8", - "--float8.enable_fsdp_float8_all_gather", - "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", - ] - ], - "FSDP+async TP+PP+torch.compile+Float8", - "fsdp+tp+cp+compile+float8", - ngpu=8, - ), - OverrideDefinitions( - [ - [ - "--training.compile", - "--parallelism.data_parallel_shard_degree=2", - "--parallelism.data_parallel_replicate_degree=2", - "--parallelism.context_parallel_degree=2", - "--model.converters float8", - "--float8.enable_fsdp_float8_all_gather", - "--float8.precompute_float8_dynamic_scale_for_fsdp", - "--float8.force_recompute_fp8_weight_in_bwd", - ] - ], - "HSDP+CP+torch.compile+Float8", - "hsdp+cp+compile+float8", - ngpu=8, - ), - ] - return integration_tests_flavors - - -def _run_cmd(cmd): - return subprocess.run([cmd], text=True, shell=True) - - -def run_test(test_flavor: OverrideDefinitions, full_path: str, output_dir: str): - # run_test supports sequence of tests. 
- test_name = test_flavor.test_name - dump_folder_arg = f"--job.dump_folder {output_dir}/{test_name}" - all_ranks = ",".join(map(str, range(test_flavor.ngpu))) - - for idx, override_arg in enumerate(test_flavor.override_args): - cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_train.sh" - # dump compile trace for debugging purpose - cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd - if test_name == "fsdp2_memory_estimation": - cmd = ( - f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " - "./scripts/estimate/run_memory_estimation.sh" - ) - cmd += " " + dump_folder_arg - if override_arg: - cmd += " " + " ".join(override_arg) - logger.info( - f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" - ) - - # save checkpoint (idx == 0) and load it for generation (idx == 1) - if test_name == "test_generate" and idx == 1: - cmd = ( - f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " - f"CHECKPOINT_DIR={output_dir}/{test_name}/checkpoint/step-10 " - "PROMPT='What is the meaning of life?' " - f"./scripts/generate/run_llama_generate.sh --out > {output_dir}/{test_name}/generated_output.json" - ) - - result = _run_cmd(cmd) - logger.info(result.stdout) - if result.returncode != 0: - raise Exception( - f"Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}" - ) - - -def run_tests(args): - integration_tests_flavors = build_test_list() - for config_file in os.listdir(args.config_dir): - if config_file.endswith(".toml"): - full_path = os.path.join(args.config_dir, config_file) - with open(full_path, "rb") as f: - config = tomllib.load(f) - is_integration_test = config["job"].get( - "use_for_integration_test", False - ) - if is_integration_test: - for test_flavor in integration_tests_flavors[config_file]: - if args.test == "all" or test_flavor.test_name == args.test: - if args.ngpu < test_flavor.ngpu: - logger.info( - f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus," - f" because --ngpu arg is {args.ngpu}" - ) - else: - run_test(test_flavor, full_path, args.output_dir) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("output_dir") - parser.add_argument( - "--config_dir", default="./torchtitan/models/llama3/train_configs" - ) - parser.add_argument( - "--test", - default="all", - help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", - ) - parser.add_argument("--ngpu", default=8, type=int) - args = parser.parse_args() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - if os.listdir(args.output_dir): - raise RuntimeError("Please provide an empty output directory.") - run_tests(args) - - -if __name__ == "__main__": - main() diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py index 5cfd94ef0..5b847c92f 100644 --- a/tests/unit_tests/data/test_metrics_aggregator.py +++ b/tests/unit_tests/data/test_metrics_aggregator.py @@ -22,7 +22,7 @@ import pytest import torch.distributed as dist -from forge.data.metrics import AggregationType, Metric, MetricsAggregator +from forge.data.dataset_metrics import AggregationType, Metric, MetricsAggregator from torch.testing._internal.common_fsdp import FSDPTest from tests.test_utils import gpu_test @@ -237,7 +237,7 @@ def test_handler_replacement_warning(self, caplog): aggregator.update(metrics) # Replace the SUM handler - should generate warning - from 
forge.data.metrics.metric_agg_handlers import SumAggHandler + from forge.data.dataset_metrics import SumAggHandler with caplog.at_level(logging.WARNING): aggregator.register_handler(AggregationType.SUM, SumAggHandler()) diff --git a/tests/unit_tests/data/test_metrics_transform.py b/tests/unit_tests/data/test_metrics_transform.py index 0a79550a4..078b511a8 100644 --- a/tests/unit_tests/data/test_metrics_transform.py +++ b/tests/unit_tests/data/test_metrics_transform.py @@ -14,7 +14,7 @@ import pytest -from forge.data.metrics import AggregationType, DefaultTrainingMetricTransform +from forge.data.dataset_metrics import AggregationType, DefaultTrainingMetricTransform class TestDefaultTrainingMetricTransform: diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index acb2b0870..c1535c8b8 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -26,8 +26,8 @@ import pytest import torch.distributed as dist -from forge.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator -from forge.datasets import HfIterableDataset +from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator +from forge.data.datasets import HfIterableDataset from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader diff --git a/tests/unit_tests/datasets/test_interleaved.py b/tests/unit_tests/datasets/test_interleaved.py index b860d5b2c..0073b905e 100644 --- a/tests/unit_tests/datasets/test_interleaved.py +++ b/tests/unit_tests/datasets/test_interleaved.py @@ -29,8 +29,8 @@ import torch import torch.distributed as dist -from forge.data.metrics import DefaultTrainingMetricTransform, MetricsAggregator -from forge.datasets import HfIterableDataset, InterleavedDataset +from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator +from forge.data.datasets import HfIterableDataset, InterleavedDataset from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader diff --git a/tests/unit_tests/datasets/test_iterable_utils.py b/tests/unit_tests/datasets/test_iterable_utils.py index b3c2c69e3..cdeced7c7 100644 --- a/tests/unit_tests/datasets/test_iterable_utils.py +++ b/tests/unit_tests/datasets/test_iterable_utils.py @@ -7,7 +7,7 @@ from typing import Any, Optional import torch -from forge.data.metrics import MetricsAggregator +from forge.data.dataset_metrics import MetricsAggregator from torch.utils.data import DataLoader diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 4f5d188f5..352fbf703 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -14,10 +14,15 @@ import torch from forge.data.collate import collate_packed -from forge.data.metrics import MetricsAggregator -from forge.datasets.hf import HfIterableDataset -from forge.datasets.packed import DPOPacker, PackedDataset, Packer, TextPacker -from forge.tools._import_guard import _SUPPORTS_FLEX_ATTENTION +from forge.data.dataset_metrics import MetricsAggregator +from forge.data.datasets import HfIterableDataset +from forge.data.datasets.packed import ( + _SUPPORTS_FLEX_ATTENTION, + DPOPacker, + PackedDataset, + Packer, + TextPacker, +) from torchdata.stateful_dataloader import StatefulDataLoader from .test_iterable_utils import generate_ckpt diff --git a/tests/unit_tests/rl/environments/test_chat.py 
b/tests/unit_tests/rl/environments/test_chat.py index 52c481979..678e0a33f 100644 --- a/tests/unit_tests/rl/environments/test_chat.py +++ b/tests/unit_tests/rl/environments/test_chat.py @@ -15,7 +15,7 @@ import torch -from forge.rl.environments.chat import ( +from forge.envs.chat import ( ChatAction, ChatEnvironment, ChatObservation, diff --git a/tests/unit_tests/test_checkpoint.py b/tests/unit_tests/test_checkpoint.py deleted file mode 100644 index 23496699a..000000000 --- a/tests/unit_tests/test_checkpoint.py +++ /dev/null @@ -1,539 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import shutil -import tempfile -import time -import unittest -from types import SimpleNamespace -from unittest import mock - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -from torchtitan.components.checkpoint import CheckpointManager, MODEL -from torchtitan.config_manager import Checkpoint as CheckpointConfig - - -class FakeOptimizersContainer: - """A fake OptimizersContainer that returns fake state dicts.""" - - def __init__(self): - self._fake_param = torch.tensor([1.0], dtype=torch.float32) - - def state_dict(self): - return {"fake_param": self._fake_param} - - def load_state_dict(self, sd: dict): - if "fake_param" in sd: - self._fake_param = sd["fake_param"] - - def init_cache_state_dict(self): - pass - - -class FakeLRSchedulersContainer: - """A fake LRSchedulersContainer that does nothing.""" - - def __init__(self): - pass - - def state_dict(self): - return {} - - def load_state_dict(self, sd: dict): - pass - - -class FakeDataLoader(DataLoader): - """A fake DataLoader that returns a fake batch.""" - - def __init__(self): - super().__init__(dataset=[], batch_size=1) - - def state_dict(self): - return {} - - def load_state_dict(self, sd: dict): - pass - - -class DummyFTManager: - """A fake FTManager-like object with enabled=False.""" - - def __init__(self): - self.enabled = False - self.manager = None - - -class DummyFuture: - def __init__(self): - self.result = mock.Mock() - - -def fake_async_save(*args, **kwargs): - return DummyFuture() - - -class DummyJobConfig: - def __init__(self, job): - self.job = job - self.checkpoint = CheckpointConfig( - enable_checkpoint=True, - async_mode="disabled", - folder="", - interval=1, - keep_latest_k=0, - last_save_model_weights_only=False, - export_dtype="float32", - exclude_from_loading=[], - initial_load_path=None, - initial_load_model_weights_only=False, - ) - self.fault_tolerance = SimpleNamespace(replica_id=0) - - -class TestCheckpointManager(unittest.TestCase): - def setUp(self): - self.base_temp_dir = tempfile.mkdtemp() - self.test_folder = os.path.join(self.base_temp_dir, self._testMethodName) - os.makedirs(self.test_folder, exist_ok=True) - - self.model_part = nn.Linear(2, 2) - self.model_parts = [self.model_part] - # TODO: Use a real OptimizerContainer here so that we can actually verify - # some optimizer.state_dict() behavior (e.g., the key being the parameter name.) 
- self.optimizers = FakeOptimizersContainer() - self.lr_schedulers = FakeLRSchedulersContainer() - self.states = {} - self.data_loader = FakeDataLoader() - self.ft_manager = DummyFTManager() - - ckpt_cfg = CheckpointConfig( - enable_checkpoint=True, - async_mode="DISABLED", - folder="", - interval=1, - keep_latest_k=2, - last_save_model_weights_only=False, - export_dtype="float32", - exclude_from_loading=[], - initial_load_path=None, - initial_load_model_weights_only=False, - ) - ft_ns = SimpleNamespace(replica_id=0) - job_ns = SimpleNamespace(dump_folder=self.test_folder) - self.job_config = SimpleNamespace( - checkpoint=ckpt_cfg, - fault_tolerance=ft_ns, - job=job_ns, - ) - - # Patch process group creation - self.patcher_group = mock.patch( - "torch.distributed.new_group", return_value="pg" - ) - self.patcher_group.start() - - def tearDown(self): - self.patcher_group.stop() - shutil.rmtree(self.base_temp_dir) - time.sleep(0.1) - - def fake_save(self, state_dict: dict, checkpoint_id: str): - os.makedirs(checkpoint_id, exist_ok=True) - sd_to_save = {} - for key, val in state_dict.items(): - if hasattr(val, "state_dict"): - sd_to_save[key] = val.state_dict() - elif isinstance(val, torch.Tensor): - sd_to_save[key] = val - torch.save(sd_to_save, os.path.join(checkpoint_id, "state_dict.pt")) - - def fake_load(self, states: dict, checkpoint_id=None): - path = os.path.join(checkpoint_id, "state_dict.pt") - loaded = torch.load(path, weights_only="False") - for key, val in loaded.items(): - if key in states and hasattr(states[key], "load_state_dict"): - states[key].load_state_dict(val) - elif key in states and isinstance(states[key], torch.Tensor): - states[key] = val - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_save_load_restores_state(self, mock_load, mock_save, mock_rank): - mock_save.side_effect = self.fake_save - mock_load.side_effect = self.fake_load - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - - w0 = self.model_part.weight.clone() - b0 = self.model_part.bias.clone() - p0 = self.optimizers._fake_param.clone() - manager.save(curr_step=1) - with torch.no_grad(): - self.model_part.weight.zero_() - self.model_part.bias.zero_() - self.optimizers._fake_param = torch.tensor([42.0], dtype=torch.float32) - manager.load(step=1) - - self.assertTrue(torch.equal(self.model_part.weight, w0)) - self.assertTrue(torch.equal(self.model_part.bias, b0)) - self.assertTrue(torch.equal(self.optimizers._fake_param, p0)) - manager.close() - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_save_and_purge_keeps_last_k_checkpoints( - self, mock_load, mock_save, mock_rank - ): - mock_save.side_effect = self.fake_save - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - - manager.save(curr_step=1) - manager.save(curr_step=2) - manager.save(curr_step=3) - deadline = time.time() + 5.0 - - while True: - exist = sorted(os.listdir(self.test_folder)) - if 
exist == ["step-2", "step-3"]: - break - if time.time() > deadline: - self.fail(f"Purge timed out; found {exist}") - time.sleep(0.05) - - self.assertListEqual(sorted(os.listdir(self.test_folder)), ["step-2", "step-3"]) - calls = [c.kwargs.get("checkpoint_id") for c in mock_save.call_args_list] - expected = [os.path.join(self.test_folder, f"step-{i}") for i in (1, 2, 3)] - self.assertListEqual(calls, expected) - sd = torch.load( - os.path.join(self.test_folder, "step-3", "state_dict.pt"), - weights_only=False, - ) - self.assertIn("optimizer", sd) - torch.testing.assert_close(sd["optimizer"]["fake_param"], torch.tensor([1.0])) - manager.close() - - @mock.patch("torch.distributed.get_rank", return_value=1) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_nonzero_rank_does_not_purge_or_save(self, mock_load, mock_save, mock_rank): - mock_save.side_effect = self.fake_save - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - manager.save(curr_step=1) - manager.save(curr_step=2) - manager.save(curr_step=3) - time.sleep(1) - self.assertListEqual( - sorted(os.listdir(self.test_folder)), ["step-1", "step-2", "step-3"] - ) - self.assertEqual(len(mock_save.call_args_list), 3) - manager.close() - - def test_load_returns_false_when_no_checkpoint_folder(self): - cfg = self.job_config.checkpoint - cfg.folder = "nonexistent" - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - self.assertFalse(manager.load(step=-1)) - manager.close() - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_load_finds_latest_and_calls_dcp_load(self, mock_load, mock_rank): - ckpt_folder = os.path.join(self.test_folder, "checkpoints") - os.makedirs(ckpt_folder, exist_ok=True) - for s in (2, 5): - d = os.path.join(ckpt_folder, f"step-{s}") - os.makedirs(d, exist_ok=True) - open(os.path.join(d, ".metadata"), "w").close() - cfg = self.job_config.checkpoint - cfg.folder = "checkpoints" - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - res = manager.load(step=-1) - expected = os.path.join(ckpt_folder, "step-5") - mock_load.assert_called_once() - args, kwargs = mock_load.call_args - self.assertEqual(args[0], manager._states_to_load(model_only=False)) - self.assertEqual(kwargs.get("checkpoint_id"), expected) - self.assertTrue(res) - manager.close() - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_interval_respects_interval(self, mock_load, mock_save, mock_rank): - """ - Test that save() only triggers on step 1 and multiples of interval, skipping others, - but respects force flag to override interval. 
- """ - cfg = self.job_config.checkpoint - cfg.interval = 3 - cfg.keep_latest_k = 0 - mock_save.side_effect = self.fake_save - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - manager.save(curr_step=1) - self.assertEqual(mock_save.call_count, 0) - manager.save(curr_step=2) - self.assertEqual(mock_save.call_count, 0) - manager.save(curr_step=2, force=True) - self.assertEqual(mock_save.call_count, 1) - manager.save(curr_step=3) - self.assertEqual(mock_save.call_count, 2) - manager.save(curr_step=4) - self.assertEqual(mock_save.call_count, 2) - manager.save(curr_step=4, force=True) - self.assertEqual(mock_save.call_count, 3) - manager.close() - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - @mock.patch("torchtitan.components.checkpoint.dcp.load") - def test_last_save_model_weights_only_and_initial_load_model_weights_only( - self, mock_load, mock_save, mock_rank - ): - mock_save.side_effect = self.fake_save - mock_load.side_effect = self.fake_load - # Phase 1: save model weights only - self.job_config.checkpoint.last_save_model_weights_only = True - manager1 = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states={MODEL: self.model_part}, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - manager1.save(curr_step=1, force=True) - path1 = os.path.join(self.test_folder, "step-1") - self.assertTrue(os.path.isdir(path1)) - # Phase 2: initial load from step-1 - cfg = self.job_config.checkpoint - cfg.last_save_model_weights_only = False - cfg.initial_load_model_weights_only = True - cfg.initial_load_path = path1 - cfg.folder = "" - self.job_config.job.dump_folder = self.test_folder - manager2 = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states={MODEL: self.model_part}, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - r1 = manager2.load(step=1) - self.assertTrue(r1) - mock_load.assert_called_once() - args1, kwargs1 = mock_load.call_args - self.assertEqual(kwargs1.get("checkpoint_id"), path1) - # Phase 3: save new step under default folder, then load that - manager2.save(curr_step=2, force=True) - # Default folder is test_folder, so step-2 under that - step2_dir = os.path.join(self.test_folder, "step-2") - self.assertTrue(os.path.isdir(step2_dir)) - r2 = manager2.load(step=2) - self.assertTrue(r2) - self.assertEqual(mock_load.call_count, 2) - args2, kwargs2 = mock_load.call_args_list[1] - self.assertEqual(kwargs2.get("checkpoint_id"), step2_dir) - manager1.close() - manager2.close() - - @mock.patch("torchtitan.components.checkpoint.dist.new_group") - @mock.patch( - "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save - ) - def test_async_save_calls_async_wait(self, mock_async_save, mock_new_group): - """ - Test that in AsyncMode.ASYNC, save() waits on previous async future. 
- """ - # Configure async mode - job_config = DummyJobConfig(job=self.job_config.job) - job_config.checkpoint.async_mode = "async" - ft_manager = DummyFTManager() - states = {"trainer": torch.tensor([0])} - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=states, - job_config=job_config, - ft_manager=ft_manager, - ) - - # First save schedules async - manager.save(curr_step=10, force=False) - future = manager.async_future - future.result.assert_not_called() - - # Second save should wait - manager.save(curr_step=20, force=False) - future.result.assert_called_once() - - # New future created - new_future = manager.async_future - new_future.result.assert_not_called() - - @mock.patch("torch.cuda.Stream") - @mock.patch("torchtitan.components.checkpoint.dist.new_group") - @mock.patch( - "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save - ) - def test_ft_async_save_calls_async_wait( - self, - mock_async_save, - mock_new_group, - mock_cuda_stream, - ): - """ - Test that with FT enabled, AsyncMode.ASYNC via FT triggers correct waits. - """ - job_config = DummyJobConfig(job=self.job_config.job) - job_config.checkpoint.async_mode = "disabled" - ft_manager = mock.Mock() - ft_manager.enabled = True - states = {"trainer": torch.tensor([0])} - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=states, - job_config=job_config, - ft_manager=ft_manager, - ) - - # Initially no future - self.assertIsNone(manager.async_future) - manager.save(curr_step=5, force=False) - self.assertIsNotNone(manager.async_future) - - manager.async_future.result.assert_not_called() - prev_future = manager.async_future - manager.save(curr_step=6, force=False) - prev_future.result.assert_called_once() - self.assertIsNotNone(manager.async_future) - manager.async_future.result.assert_not_called() - - @mock.patch("torch.distributed.get_rank", return_value=0) - @mock.patch("torchtitan.components.checkpoint.dcp.save") - def test_enable_first_step_checkpoint(self, mock_save, mock_rank): - """ - Test that enable_first_step_checkpoint triggers checkpoint save at step 1. 
- """ - mock_save.side_effect = self.fake_save - - # Test with enable_first_step_checkpoint=False (default case) - cfg = self.job_config.checkpoint - cfg.interval = 10 # Set interval to 10 so step 1 wouldn't normally trigger save - cfg.keep_latest_k = 0 # Disable purging to avoid confusion - - manager = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - - # Step 1 should not trigger save when enable_first_step_checkpoint=False - # and not at interval - manager.save(curr_step=1) - self.assertEqual(mock_save.call_count, 0) - - # Step 10 should trigger save due to interval - manager.save(curr_step=10) - self.assertEqual(mock_save.call_count, 1) - - manager.close() - - # Test with enable_first_step_checkpoint=True - mock_save.reset_mock() - cfg.enable_first_step_checkpoint = True - - manager2 = CheckpointManager( - dataloader=self.data_loader, - model_parts=self.model_parts, - optimizers=self.optimizers, - lr_schedulers=self.lr_schedulers, - states=self.states, - job_config=self.job_config, - ft_manager=self.ft_manager, - ) - - # Step 1 should trigger save due to enable_first_step_checkpoint=True - manager2.save(curr_step=1) - self.assertEqual(mock_save.call_count, 1) - - # Step 2 should not trigger save (not at interval and not forced) - manager2.save(curr_step=2) - self.assertEqual(mock_save.call_count, 1) - - # Step 10 should trigger save due to interval - manager2.save(curr_step=10) - self.assertEqual(mock_save.call_count, 2) - - manager2.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit_tests/test_controller.py b/tests/unit_tests/test_controller.py deleted file mode 100644 index 320a62a37..000000000 --- a/tests/unit_tests/test_controller.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Tests for monarch_utils.py. - -Run this with: -$ pytest ./tests/unit_tests/test_controller.py - -""" - -import operator - -import pytest -from forge.controller.stack import _common_ancestor, stack, StackedActorMeshRef -from monarch.actor import Accumulator, Actor, endpoint, local_proc_mesh - - -class Counter(Actor): - def __init__(self, v: int): - self.v = v - - @endpoint - async def incr(self): - self.v += 1 - - @endpoint - async def value(self) -> int: - return self.v - - -class CounterA(Counter): - def __init__(self, v: int, step: int = 1): - super().__init__(v) - self.step = step - - @endpoint - async def decr(self): - """Decrement the counter by step. - This is a function that is unique to CounterA. - """ - self.v -= self.step - - @endpoint - async def incr(self): - self.v += self.step - - -class CounterB(Counter): - def __init__(self, v: int, multiplier: int = 2): - super().__init__(v) - self.multiplier = multiplier - - @endpoint - async def reset(self): - """Reset the counter to 0. - This is a function that is unique to CounterB. 
- """ - self.v = 0 - - @endpoint - async def incr(self): - self.v *= self.multiplier - - -class CounterC(Actor): - def __init__(self, v: int, increment: int = 1): - self.v = v - self.increment = increment - - @endpoint - async def incr(self): - self.v += self.increment - - @endpoint - async def value(self) -> int: - return self.v - - @endpoint - async def double(self): - self.v *= 2 - - -class CounterD(CounterA): - def __init__(self, v: int, step: int = 1, factor: int = 3): - super().__init__(v, step) - self.factor = factor - - @endpoint - async def multiply(self): - """Multiply the counter by factor. - This is a function that is unique to CounterD. - """ - self.v *= self.factor - - @endpoint - async def decr(self): - """Override decr to decrement by step * factor.""" - self.v -= self.step * self.factor - - -def test_common_ancestor(): - proc = local_proc_mesh(gpus=1).get() - - # Test with same class - counter1 = proc.spawn("counter1", Counter, 0).get() - counter2 = proc.spawn("counter2", Counter, 0).get() - assert _common_ancestor(counter1, counter2) == Counter - - # Test with parent-child relationship - counter_a = proc.spawn("counter_a", CounterA, 0).get() - assert _common_ancestor(counter1, counter_a) == Counter - - # Test with siblings - counter_b = proc.spawn("counter_b", CounterB, 0).get() - assert _common_ancestor(counter_a, counter_b) == Counter - - # Test with unrelated classes - counter_c = proc.spawn("counter_c", CounterC, 0).get() - assert _common_ancestor(counter_a, counter_c) == Actor - - # Test with empty list - assert _common_ancestor() is None - - # Test with mixed hierarchy - assert _common_ancestor(counter1, counter_a, counter_b) == Counter - assert _common_ancestor(counter_a, counter_b, counter_c) == Actor - - -def test_identical_actor_stack(): - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - - stacked = stack(counter1, counter2) - assert stacked is not None - assert isinstance(stacked, StackedActorMeshRef) - - result = stacked.incr.call() - [r.get() for r in result] - assert counter1.value.choose().get() == 1 - assert counter2.value.choose().get() == 1 - - -def test_heterogeneous_actor_stack(): - """Test stacking actors of different types that share a common ancestor.""" - proc = local_proc_mesh(gpus=1).get() - - # Create different types of counters - counter = proc.spawn("counter", Counter, 0).get() - counter_a = proc.spawn("counter_a", CounterA, 0).get() - counter_b = proc.spawn("counter_b", CounterB, 0).get() - - # Stack them together - they should use Counter as the common interface - stacked = stack(counter, counter_a, counter_b) - - # Verify the stacked actor has the common endpoints - assert hasattr(stacked, "incr") - assert hasattr(stacked, "value") - - # Verify unique endpoints are not accessible on the stacked actor - assert not hasattr(stacked, "decr") # CounterA specific - assert not hasattr(stacked, "reset") # CounterB specific - - # Test that the common endpoints work - res = stacked.incr.call() - [r.get() for r in res] - - # Verify each actor was affected according to its implementation - assert counter.value.choose().get() == 1 # Regular counter: +1 - assert counter_a.value.choose().get() == 1 # CounterA: +step (default 1) - assert counter_b.value.choose().get() == 0 # CounterB: *multiplier (default 2) - - -def test_stack_with_custom_interface(): - """Test stacking actors with a specified interface.""" - proc 
= local_proc_mesh(gpus=1).get() - - # Create different types of counters - counter_a = proc.spawn("counter_a", CounterA, 0).get() - counter_d = proc.spawn("counter_d", CounterD, 0).get() - - # Without specifying interface, they would use CounterA as common ancestor - # But we want to use Counter interface instead - stacked = stack(counter_a, counter_d, interface=Counter) - - # Verify the stacked actor has only Counter endpoints - assert hasattr(stacked, "incr") - assert hasattr(stacked, "value") - - # Verify CounterA/CounterD specific endpoints are not accessible - assert not hasattr(stacked, "decr") # Should not be available - assert not hasattr(stacked, "multiply") # CounterD specific - - # Test that the common endpoints work - res = stacked.incr.call() - [r.get() for r in res] - - # Verify each actor was affected according to its implementation - assert counter_a.value.choose().get() == 1 # CounterA: +step (default 1) - assert counter_d.value.choose().get() == 1 # CounterD: +step (default 1) - - -def test_stacked_endpoint_consistency(): - """Tests that the StackedEndpoint shares the same APIs as EndPoint.""" - - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - - stacked = stack(counter1, counter2) - regular_endpoint = counter1.incr - stacked_endpoint = stacked.incr - - # Verify both endpoints implement all methods from EndpointInterface - for method_name in ["call", "broadcast", "choose", "stream"]: - assert hasattr(regular_endpoint, method_name), f"Endpoint missing {method_name}" - assert hasattr( - stacked_endpoint, method_name - ), f"StackedEndpoint missing {method_name}" - - -def test_stacked_endpoint_choose(): - """Tests that the StackedEndpoint.choose method works correctly.""" - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - - stacked = stack(counter1, counter2) - stacked_endpoint = stacked.incr - - # Test choose - stacked_endpoint.choose().get() - # At least one counter should be incremented - assert counter1.value.choose().get() + counter2.value.choose().get() >= 1 - - -def test_stacked_endpoint_call(): - """Tests that the StackedEndpoint.call method works correctly.""" - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - - stacked = stack(counter1, counter2) - stacked_endpoint = stacked.incr - - # Test call - result = stacked_endpoint.call() - [r.get() for r in result] - assert isinstance(result, list) - assert len(result) == 2 - - # Verify both counters were incremented - assert counter1.value.choose().get() == 1 - assert counter2.value.choose().get() == 1 - - -def test_stacked_endpoint_broadcast(): - """Tests that the StackedEndpoint.broadcast method works correctly.""" - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - counter1 = proc1.spawn("counter1", Counter, 0).get() - counter2 = proc2.spawn("counter2", Counter, 0).get() - - stacked = stack(counter1, counter2) - stacked_endpoint = stacked.incr - - # Test broadcast - stacked_endpoint.broadcast() - # Both counters should be incremented - assert counter1.value.choose().get() == 1 - assert counter2.value.choose().get() == 1 - - -@pytest.mark.asyncio -async def 
test_stacked_endpoint_stream(): - """Tests that the StackedEndpoint.stream method works correctly.""" - proc1 = await local_proc_mesh(gpus=1) - proc2 = await local_proc_mesh(gpus=1) - counter1 = await proc1.spawn("counter1", Counter, 0) - counter2 = await proc2.spawn("counter2", Counter, 0) - - stacked = stack(counter1, counter2) - stacked_endpoint = stacked.incr - - # Test stream - async def test_stream(): - return [await x for x in stacked_endpoint.stream()] - - results = await test_stream() - assert len(results) == 2 - - # Verify both counters were incremented - assert counter1.value.choose().get() == 1 - assert counter2.value.choose().get() == 1 - - -def test_stacked_actor_with_accumulator(): - """Tests that Accumulator works correctly with StackedActor endpoints.""" - proc1 = local_proc_mesh(gpus=1).get() - proc2 = local_proc_mesh(gpus=1).get() - - counter1 = proc1.spawn("counter1", Counter, 5).get() - counter2 = proc2.spawn("counter2", Counter, 10).get() - stacked = stack(counter1, counter2) - - acc = Accumulator(stacked.value, 0, operator.add) - result = acc.accumulate().get() - assert result == 15 - - stacked.incr.broadcast() - result = acc.accumulate().get() - assert result == 17 diff --git a/tests/unit_tests/test_dataset_checkpointing.py b/tests/unit_tests/test_dataset_checkpointing.py deleted file mode 100644 index 00998fc49..000000000 --- a/tests/unit_tests/test_dataset_checkpointing.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -import torch -from datasets import load_dataset -from torchtitan.config_manager import ConfigManager -from torchtitan.datasets.hf_datasets import build_hf_dataloader, DatasetConfig, DATASETS -from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer - - -class TestDatasetCheckpointing(unittest.TestCase): - def setUp(self): - DATASETS["c4_test_streaming"] = DatasetConfig( - path="tests/assets/c4_test", - loader=lambda path: load_dataset(path, split="train").to_iterable_dataset( - num_shards=4 - ), - text_processor=lambda sample: sample["text"], - ) - - def tearDown(self): - del DATASETS["c4_test_streaming"] - - def test_c4_resumption(self): - for dataset_name in ["c4_test", "c4_test_streaming"]: - for world_size in [2, 4]: - for rank in range(world_size): - batch_size = 1 - seq_len = 1024 - - dl = self._build_dataloader( - dataset_name, batch_size, seq_len, world_size, rank - ) - - it = iter(dl) - for _ in range(250): - next(it) - state = dl.state_dict() - - # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above - dl_resumed = self._build_dataloader( - dataset_name, batch_size, seq_len, world_size, rank - ) - dl_resumed.load_state_dict(state) - it_resumed = iter(dl_resumed) - - for _ in range(500): - expected_input_ids, expected_labels = next(it) - input_ids, labels = next(it_resumed) - assert torch.equal( - input_ids["input"], expected_input_ids["input"] - ) - assert torch.equal(labels, expected_labels) - - def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank): - tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model") - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--training.dataset", - dataset_name, - "--training.local_batch_size", - str(batch_size), - "--training.seq_len", - str(seq_len), - ] - ) - - return 
build_hf_dataloader( - tokenizer=tokenizer, - dp_world_size=world_size, - dp_rank=rank, - job_config=config, - ) diff --git a/tests/unit_tests/test_job_config.py b/tests/unit_tests/test_job_config.py deleted file mode 100644 index d13c0db62..000000000 --- a/tests/unit_tests/test_job_config.py +++ /dev/null @@ -1,316 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import tempfile -import unittest -from dataclasses import dataclass - -import pytest -import tomli_w -from torchtitan.config_manager import ConfigManager, JobConfig - - -class TestJobConfig(unittest.TestCase): - def test_command_line_args(self): - config_manager = ConfigManager() - config = config_manager.parse_args([]) - assert config.training.steps == 10000 - - def test_job_config_file(self): - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - ] - ) - assert config.training.steps == 10 - - def test_job_file_does_not_exist(self): - with pytest.raises(FileNotFoundError): - config_manager = ConfigManager() - config_manager.parse_args(["--job.config_file", "ohno.toml"]) - - def test_empty_config_file(self): - with tempfile.NamedTemporaryFile() as fp: - config_manager = ConfigManager() - config = config_manager.parse_args(["--job.config_file", fp.name]) - assert config.job.description - - def test_job_config_file_cmd_overrides(self): - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - "--job.dump_folder", - "/tmp/test_tt/", - ] - ) - assert config.job.dump_folder == "/tmp/test_tt/" - - def test_parse_pp_split_points(self): - toml_splits = ["layers.2", "layers.4", "layers.6"] - cmdline_splits = ["layers.1", "layers.3", "layers.5"] - # no split points specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - ] - ) - assert config.parallelism.pipeline_parallel_split_points == [] - - # toml has no split points, but cmdline splits are specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - "--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), - ] - ) - assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points - - # toml has split points, cmdline does not - with tempfile.NamedTemporaryFile() as fp: - with open(fp.name, "wb") as f: - tomli_w.dump( - { - "parallelism": { - "pipeline_parallel_split_points": toml_splits, - } - }, - f, - ) - config_manager = ConfigManager() - config = config_manager.parse_args(["--job.config_file", fp.name]) - assert ( - config.parallelism.pipeline_parallel_split_points == toml_splits - ), config.parallelism.pipeline_parallel_split_points - - # toml has split points, cmdline overrides them - with tempfile.NamedTemporaryFile() as fp: - with open(fp.name, "wb") as f: - tomli_w.dump( - { - "parallelism": { - "pipeline_parallel_split_points": toml_splits, - } - }, - f, - ) - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - fp.name, - 
"--parallelism.pipeline_parallel_split_points", - ",".join(cmdline_splits), - ] - ) - assert ( - config.parallelism.pipeline_parallel_split_points == cmdline_splits - ), config.parallelism.pipeline_parallel_split_points - - def test_parse_exclude_from_loading(self): - toml_splits = ["optimizer", "dataloader"] - cmdline_splits = ["optimizer", "lr_scheduler"] - # no split points specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - ] - ) - assert config.checkpoint.exclude_from_loading == [] - - # toml has no split points, but cmdline splits are specified - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - "./torchtitan/models/llama3/train_configs/debug_model.toml", - "--checkpoint.exclude_from_loading", - ",".join(cmdline_splits), - ] - ) - assert ( - config.checkpoint.exclude_from_loading == cmdline_splits - ), config.checkpoint.exclude_from_loading - - # toml has split points, cmdline does not - with tempfile.NamedTemporaryFile() as fp: - with open(fp.name, "wb") as f: - tomli_w.dump( - { - "checkpoint": { - "exclude_from_loading": toml_splits, - } - }, - f, - ) - config_manager = ConfigManager() - config = config_manager.parse_args(["--job.config_file", fp.name]) - assert ( - config.checkpoint.exclude_from_loading == toml_splits - ), config.checkpoint.exclude_from_loading - - # toml has split points, cmdline overrides them - with tempfile.NamedTemporaryFile() as fp: - with open(fp.name, "wb") as f: - tomli_w.dump( - { - "checkpoint": { - "exclude_from_loading": toml_splits, - } - }, - f, - ) - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - "--job.config_file", - fp.name, - "--checkpoint.exclude_from_loading", - ",".join(cmdline_splits), - ] - ) - assert ( - config.checkpoint.exclude_from_loading == cmdline_splits - ), config.checkpoint.exclude_from_loading - - def test_job_config_model_converters_split(self): - config_manager = ConfigManager() - config = config_manager.parse_args([]) - assert config.model.converters == [] - - config_manager = ConfigManager() - config = config_manager.parse_args(["--model.converters", "float8,mxfp"]) - assert config.model.converters == ["float8", "mxfp"] - - def test_print_help(self): - from tyro.extras import get_parser - - parser = get_parser(ConfigManager) - parser.print_help() - - def test_extend_jobconfig_directly(self): - @dataclass - class CustomCheckpoint: - convert_path: str = "/custom/path" - fake_model: bool = True - - @dataclass - class CustomJobConfig: - checkpoint: CustomCheckpoint - - MergedJobConfig = ConfigManager._merge_configs(JobConfig, CustomJobConfig) - - cli_args = [ - "--checkpoint.convert_path=/override/path", - "--checkpoint.fake_model", - ] - - config_manager = ConfigManager(config_cls=MergedJobConfig) - config = config_manager.parse_args(cli_args) - - assert config.checkpoint.convert_path == "/override/path" - assert config.checkpoint.fake_model is True - assert hasattr(config, "model") - - def test_custom_parser(self): - path = "tests.assets.extend_jobconfig_example" - - config_manager = ConfigManager() - config = config_manager.parse_args( - [ - f"--experimental.custom_args_module={path}", - "--custom_args.how-is-your-day", - "bad", - "--model.converters", - "float8,mxfp", - ] - ) - assert config.custom_args.how_is_your_day == "bad" - assert config.model.converters == ["float8", "mxfp"] - result = config.to_dict() - assert 
isinstance(result, dict)
-
-        # There will be a SystemExit raised by ArgumentParser with exit status 2.
-        with self.assertRaisesRegex(SystemExit, "2"):
-            config = config_manager.parse_args(
-                [
-                    f"--experimental.custom_args_module={path}",
-                    "--custom_args.how-is-your-day",
-                    "bad",
-                    "--model.converters",
-                    "float8,mxfp",
-                    "--abcde",
-                ]
-            )
-
-        with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
-            tomli_w.dump(
-                {
-                    "experimental": {
-                        "custom_args_module": path,
-                    }
-                },
-                fp,
-            )
-            fp.flush()
-
-            config_manager = ConfigManager()
-            config = config_manager.parse_args(
-                [
-                    f"--job.config_file={fp.name}",
-                    f"--experimental.custom_args_module={path}",
-                    "--custom_args.how-is-your-day",
-                    "bad",
-                    "--model.converters",
-                    "float8,mxfp",
-                ]
-            )
-            assert config.custom_args.how_is_your_day == "bad"
-            assert config.training.my_custom_steps == 32
-            assert config.model.converters == ["float8", "mxfp"]
-            result = config.to_dict()
-            assert isinstance(result, dict)
-
-        with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as fp:
-            tomli_w.dump(
-                {
-                    "experimental": {
-                        "custom_args_module": path,
-                    },
-                    "custom_args": {"how_is_your_day": "really good"},
-                    "model": {"converters": ["float8", "mxfp"]},
-                },
-                fp,
-            )
-            fp.flush()
-
-            config_manager = ConfigManager()
-            config = config_manager.parse_args(
-                [
-                    f"--job.config_file={fp.name}",
-                ]
-            )
-
-        assert config.custom_args.how_is_your_day == "really good"
-        assert config.model.converters == ["float8", "mxfp"]
-        result = config.to_dict()
-        assert isinstance(result, dict)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/unit_tests/test_lr_scheduler.py b/tests/unit_tests/test_lr_scheduler.py
deleted file mode 100644
index e66d4d265..000000000
--- a/tests/unit_tests/test_lr_scheduler.py
+++ /dev/null
@@ -1,280 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
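A note on the job-config tests above: the precedence they pin down is TOML values first, then CLI flags override field by field. A minimal sketch of that behavior, using the real `ConfigManager.parse_args` API exercised above; the temp file and the `dump_folder` field are chosen purely for illustration:

```python
# Hedged sketch: CLI flags override TOML values field by field.
import tempfile

import tomli_w
from torchtitan.config_manager import ConfigManager

with tempfile.NamedTemporaryFile(suffix=".toml") as fp:
    tomli_w.dump({"job": {"dump_folder": "/from/toml"}}, fp)
    fp.flush()
    config = ConfigManager().parse_args(
        ["--job.config_file", fp.name, "--job.dump_folder", "/from/cli"]
    )

assert config.job.dump_folder == "/from/cli"  # CLI wins over the TOML value
```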
-
-import unittest
-from unittest.mock import MagicMock
-
-import torch
-from torch.optim import Adam
-
-from torchtitan.components.lr_scheduler import build_lr_schedulers
-from torchtitan.components.optimizer import OptimizersContainer
-
-
-class TestLRScheduler(unittest.TestCase):
-    def setUp(self):
-        # Create a simple model with parameters
-        self.model = torch.nn.Linear(10, 10)
-        # Create an optimizer
-        self.optimizer = Adam(self.model.parameters(), lr=0.1)
-        # Create an optimizer container
-        self.optimizer_container = MagicMock(spec=OptimizersContainer)
-        self.optimizer_container.__iter__.return_value = iter([self.optimizer])
-        self.optimizer_container.__len__.return_value = 1
-
-    def create_job_config(
-        self,
-        training_steps=10,
-        warmup_steps=None,
-        decay_ratio=None,
-        decay_type=None,
-        lr_min=None,
-    ):
-        # Create a job config with the specified parameters
-        from torchtitan.config_manager import ConfigManager
-
-        args = [
-            "--training.steps",
-            str(training_steps),
-        ]
-
-        args += (
-            ["--lr_scheduler.warmup_steps", str(warmup_steps)]
-            if warmup_steps is not None
-            else []
-        )
-        args += (
-            ["--lr_scheduler.decay_ratio", str(decay_ratio)]
-            if decay_ratio is not None
-            else []
-        )
-        args += (
-            ["--lr_scheduler.decay_type", decay_type] if decay_type is not None else []
-        )
-        args += ["--lr_scheduler.lr_min", str(lr_min)] if lr_min is not None else []
-
-        config_manager = ConfigManager()
-        # Create base config with parameters passed directly
-        config = config_manager.parse_args(args)
-
-        return config
-
-    def test_linear_warmup_decay(self):
-        """Test the linear warmup followed by linear decay schedule."""
-        # Create a job config with 10 steps, 2 warmup steps, and linear decay
-        config = self.create_job_config(
-            training_steps=10,
-            warmup_steps=2,
-            decay_ratio=None,  # Use default decay: start decay immediately
-            decay_type=None,
-            lr_min=None,
-        )
-
-        # Build the lr scheduler
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Expected adjustment factors for each step
-        expected_factors = [
-            0.5,  # Step 0: 50% of max LR (warmup)
-            1.0,  # Step 1: 100% of max LR (warmup complete)
-            1.0,  # Step 2: a stable step is added manually so the LR does not drop to 0 at the last step
-            7.0 / 8.0,  # Step 3: 7/8 of max LR
-            6.0 / 8.0,  # Step 4: 3/4 of max LR
-            5.0 / 8.0,  # Step 5: 5/8 of max LR
-            4.0 / 8.0,  # Step 6: 1/2 of max LR
-            3.0 / 8.0,  # Step 7: 3/8 of max LR
-            2.0 / 8.0,  # Step 8: 1/4 of max LR
-            1.0 / 8.0,  # Step 9: 1/8 of max LR
-        ]
-
-        # Check the learning rate at each step
-        for i, factor in enumerate(expected_factors):
-            # The LambdaLR multiplies the base lr by the factor
-            expected_lr = 0.1 * factor
-            self.assertAlmostEqual(
-                self.optimizer.param_groups[0]["lr"],
-                expected_lr,
-                places=6,
-                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
-            )
-            lr_scheduler.step()
-
-    def test_warmup_stable_decay(self):
-        """Test warmup followed by stable phase and then decay."""
-        # Create a job config with 10 steps, 2 warmup steps, 3 stable steps, and 5 decay steps
-        config = self.create_job_config(
-            training_steps=10,
-            warmup_steps=2,
-            decay_ratio=0.5,  # 50% of steps for decay
-            decay_type="linear",
-            lr_min=0.0,
-        )
-
-        # Build the lr scheduler
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Expected adjustment factors for each step
-        expected_factors = [
-            0.5,  # Step 0: 50% of max LR (warmup)
-            1.0,  # Step 1: 100% of max LR (warmup complete)
-            1.0,  # Step 2: Stable phase
-            1.0,  # 
Step 3: Stable phase
-            1.0,  # Step 4: Stable phase
-            1.0,  # Step 5: a stable step is added manually so the LR does not drop to 0 at the last step
-            0.8,  # Step 6: Linear decay starts (80% of max LR)
-            0.6,  # Step 7: 60% of max LR
-            0.4,  # Step 8: 40% of max LR
-            0.2,  # Step 9: 20% of max LR
-        ]
-
-        # Check the learning rate at each step
-        for i, factor in enumerate(expected_factors):
-            expected_lr = 0.1 * factor
-            self.assertAlmostEqual(
-                self.optimizer.param_groups[0]["lr"],
-                expected_lr,
-                places=6,
-                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
-            )
-            lr_scheduler.step()
-
-    def test_min_lr(self):
-        """Test that the learning rate doesn't go below the minimum."""
-        # Create a job config with a minimum learning rate
-        config = self.create_job_config(
-            training_steps=10,
-            warmup_steps=2,
-            decay_ratio=None,
-            decay_type="linear",
-            lr_min=0.2,  # 20% of base LR as minimum
-        )
-
-        # Build the lr scheduler
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Step through all steps
-        for _ in range(10):
-            lr_scheduler.step()
-
-        # After all steps, LR should be at minimum (0.1 * 0.2 = 0.02)
-        self.assertAlmostEqual(self.optimizer.param_groups[0]["lr"], 0.02, places=6)
-
-    def test_warmup_exceeds_training(self):
-        """Test when warmup steps exceed training steps."""
-        # Create a job config where warmup steps > training steps
-        config = self.create_job_config(
-            training_steps=5,
-            warmup_steps=10,  # More than training steps
-            decay_ratio=None,
-            decay_type="linear",
-            lr_min=0.0,
-        )
-
-        # Build the lr scheduler - should adjust warmup steps
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Expected adjustment factors for each step: warmup is clamped to the
-        # 5 training steps, so the whole run is linear warmup
-        expected_factors = [
-            0.2,  # Step 0: 20% of max LR (warmup)
-            0.4,  # Step 1: 40% of max LR (warmup)
-            0.6,  # Step 2: 60% of max LR (warmup)
-            0.8,  # Step 3: 80% of max LR (warmup)
-            1.0,  # Step 4: 100% of max LR (warmup complete)
-        ]
-
-        # Check the learning rate at each step
-        for i, factor in enumerate(expected_factors):
-            expected_lr = 0.1 * factor
-            self.assertAlmostEqual(
-                self.optimizer.param_groups[0]["lr"],
-                expected_lr,
-                places=6,
-                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
-            )
-            lr_scheduler.step()
-
-    def test_warmup_stable_only(self):
-        """Test warmup followed by stable phase only, with no decay phase."""
-        # Create a job config with 10 steps, 2 warmup steps, and no decay phase
-        config = self.create_job_config(
-            training_steps=10,
-            warmup_steps=2,
-            decay_ratio=0.0,  # 0% of steps for decay (no decay)
-            decay_type="linear",
-            lr_min=0.0,
-        )
-
-        # Build the lr scheduler
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Expected adjustment factors for each step
-        expected_factors = [
-            0.5,  # Step 0: 50% of max LR (warmup)
-            1.0,  # Step 1: 100% of max LR (warmup complete)
-            1.0,  # Step 2: a stable step is added manually so the LR does not drop to 0 at the last step
-            1.0,  # Step 3: Stable phase
-            1.0,  # Step 4: Stable phase
-            1.0,  # Step 5: Stable phase
-            1.0,  # Step 6: Stable phase
-            1.0,  # Step 7: Stable phase
-            1.0,  # Step 8: Stable phase
-            1.0,  # Step 9: Stable phase
-        ]
-
-        # Check the learning rate at each step
-        for i, factor in enumerate(expected_factors):
-            expected_lr = 0.1 * factor
-            self.assertAlmostEqual(
-                self.optimizer.param_groups[0]["lr"],
-                expected_lr,
-                places=6,
-                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
-            )
-            lr_scheduler.step()
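The expected-factor tables in these scheduler tests all share one shape: linear warmup, a stable plateau that includes one manually added step, then linear decay that stops one step short of zero. The following is a reconstruction of that per-step factor from the tests themselves; the function name and the exact boundary handling are assumptions, not torchtitan's actual implementation:

```python
# Reconstruction of the per-step LR factor the tests above expect.
# stable_steps includes the manually added stable step; decay never reaches 0.
def lr_factor(step: int, warmup_steps: int, stable_steps: int, decay_steps: int) -> float:
    if step < warmup_steps:
        return (step + 1) / warmup_steps  # linear warmup
    if step < warmup_steps + stable_steps:
        return 1.0  # stable plateau
    # linear decay; the final step lands at 1 / (decay_steps + 1), not 0
    return 1.0 - (step - warmup_steps - stable_steps + 1) / (decay_steps + 1)


# Matches test_linear_warmup_decay (10 steps: 2 warmup, 1 stable, 7 decay):
assert lr_factor(0, 2, 1, 7) == 0.5
assert lr_factor(3, 2, 1, 7) == 7 / 8
assert lr_factor(9, 2, 1, 7) == 1 / 8
```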
-
-    def test_warmup_plus_decay_exceeds_training(self):
-        """Test when warmup + decay steps exceed training steps."""
-        # Create a job config where warmup + decay steps > training steps
-        # Expected behavior: warmup steps = 5, decay steps = 5
-        config = self.create_job_config(
-            training_steps=10,
-            warmup_steps=5,
-            decay_ratio=0.8,  # 80% of steps for decay (8 steps)
-            decay_type="linear",
-            lr_min=0.0,
-        )
-
-        # Build the lr scheduler - should adjust decay steps
-        lr_scheduler = build_lr_schedulers(self.optimizer_container, config)
-
-        # Expected adjustment factors for each step
-        expected_factors = [
-            0.2,  # Step 0: 20% of max LR (warmup)
-            0.4,  # Step 1: 40% of max LR (warmup)
-            0.6,  # Step 2: 60% of max LR (warmup)
-            0.8,  # Step 3: 80% of max LR (warmup)
-            1.0,  # Step 4: 100% of max LR (warmup complete)
-            1.0,  # Step 5: a stable step is added manually so the LR does not drop to 0 at the last step
-            0.8,  # Step 6: Linear decay starts (80% of max LR)
-            0.6,  # Step 7: 60% of max LR
-            0.4,  # Step 8: 40% of max LR
-            0.2,  # Step 9: 20% of max LR
-        ]
-
-        # Check the learning rate at each step
-        for i, factor in enumerate(expected_factors):
-            expected_lr = 0.1 * factor
-            self.assertAlmostEqual(
-                self.optimizer.param_groups[0]["lr"],
-                expected_lr,
-                places=6,
-                msg=f"Step {i}: Expected LR {expected_lr}, got {self.optimizer.param_groups[0]['lr']}",
-            )
-            lr_scheduler.step()
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/unit_tests/test_model_converter.py b/tests/unit_tests/test_model_converter.py
deleted file mode 100644
index bb5a1ecc4..000000000
--- a/tests/unit_tests/test_model_converter.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
- -from torchtitan.components.quantization.float8 import Float8Converter -from torchtitan.config_manager import ConfigManager -from torchtitan.distributed import ParallelDims -from torchtitan.protocols.model_converter import ( - build_model_converters, - ModelConvertersContainer, -) - - -def build_parallel_dims(job_config, world_size): - parallelism_config = job_config.parallelism - parallel_dims = ParallelDims( - dp_shard=parallelism_config.data_parallel_shard_degree, - dp_replicate=parallelism_config.data_parallel_replicate_degree, - cp=parallelism_config.context_parallel_degree, - tp=parallelism_config.tensor_parallel_degree, - pp=parallelism_config.pipeline_parallel_degree, - world_size=world_size, - enable_loss_parallel=not parallelism_config.disable_loss_parallel, - ) - return parallel_dims - - -def test_build_model_converters_empty_list(): - config_manager = ConfigManager() - config = config_manager.parse_args([]) - parallel_dims = build_parallel_dims(config, 1) - - model_converters = build_model_converters(config, parallel_dims) - assert isinstance(model_converters, ModelConvertersContainer) - assert model_converters.converters == [] - - -def test_build_model_converters_float8_converter(): - config_manager = ConfigManager() - config = config_manager.parse_args( - ["--model.converters", "float8", "--float8.emulate"] - ) - parallel_dims = build_parallel_dims(config, 1) - - model_converters = build_model_converters(config, parallel_dims) - assert isinstance(model_converters, ModelConvertersContainer) - assert len(model_converters.converters) == 1 - assert isinstance(model_converters.converters[0], Float8Converter) diff --git a/tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py similarity index 100% rename from tests/test_replay_buffer.py rename to tests/unit_tests/test_replay_buffer.py diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py deleted file mode 100644 index 15780d10a..000000000 --- a/tests/unit_tests/test_train_spec.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
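For orientation, the converter tests above exercise a three-step flow: parse `--model.converters`, build `ParallelDims`, then build the converter container. A sketch of that flow under the same assumptions; `build_parallel_dims` is the helper defined in the removed file, and the trailing `convert` call is hedged since the tests only assert construction:

```python
# Sketch of the flow the removed test_model_converter.py exercised.
from torchtitan.config_manager import ConfigManager
from torchtitan.protocols.model_converter import build_model_converters

config = ConfigManager().parse_args(
    ["--model.converters", "float8", "--float8.emulate"]
)
# build_parallel_dims is the small helper defined in the removed test file
parallel_dims = build_parallel_dims(config, world_size=1)
converters = build_model_converters(config, parallel_dims)
# Downstream, the trainer would apply each registered converter to the model
# parts; the tests above only assert that construction picks Float8Converter.
```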
-
-from functools import partial
-
-import pytest
-import torch
-import torch.nn as nn
-from torchtitan.components.loss import build_cross_entropy_loss
-from torchtitan.components.lr_scheduler import build_lr_schedulers
-from torchtitan.components.optimizer import build_optimizers, OptimizersContainer
-from torchtitan.config_manager import JobConfig
-from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
-from torchtitan.models.llama3 import parallelize_llama, pipeline_llama
-from torchtitan.protocols.train_spec import (
-    apply_to_train_specs,
-    BaseModelArgs,
-    get_train_spec,
-    ModelProtocol,
-    register_train_spec,
-    TrainSpec,
-)
-
-
-class FakeModel(nn.Module, ModelProtocol):
-    def __init__(self, model_args: BaseModelArgs) -> None:
-        super().__init__()
-        self.linear = nn.Linear(8, 8)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.linear(x)
-
-    def init_weights(self, buffer_device: torch.device | None = None) -> None:
-        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)
-
-
-def fake_build_optimizers(
-    model_parts: list[nn.Module], job_config: JobConfig
-) -> OptimizersContainer:
-    optimizer_kwargs = {
-        "lr": 0.1,
-        "betas": (0.9, 0.95),
-        "weight_decay": 0.1,
-        "fused": True,
-        "foreach": False,
-    }
-    return OptimizersContainer(
-        model_parts=model_parts,
-        optimizer_cls=torch.optim.Adam,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-
-class TestTrainSpec:
-    def test_register_train_spec(self):
-        fake_config = {"fake": None}
-        spec = TrainSpec(
-            name="fake",
-            cls=FakeModel,
-            config=fake_config,
-            parallelize_fn=parallelize_llama,
-            pipelining_fn=pipeline_llama,
-            build_optimizers_fn=build_optimizers,
-            build_lr_schedulers_fn=build_lr_schedulers,
-            build_dataloader_fn=build_hf_dataloader,
-            build_tokenizer_fn=build_tiktoken_tokenizer,
-            build_loss_fn=build_cross_entropy_loss,
-        )
-        register_train_spec(spec)
-        new_spec = get_train_spec("fake")
-        assert new_spec == spec
-
-        with pytest.raises(ValueError):
-            new_spec = get_train_spec("fake2")
-
-    def test_optim_hook(self):
-        fake_config = {"fake": None}
-        spec = TrainSpec(
-            name="fake2",
-            cls=FakeModel,
-            config=fake_config,
-            parallelize_fn=parallelize_llama,
-            pipelining_fn=pipeline_llama,
-            build_optimizers_fn=fake_build_optimizers,
-            build_lr_schedulers_fn=build_lr_schedulers,
-            build_dataloader_fn=build_hf_dataloader,
-            build_tokenizer_fn=build_tiktoken_tokenizer,
-            build_loss_fn=build_cross_entropy_loss,
-        )
-        register_train_spec(spec)
-        new_spec = get_train_spec("fake2")
-
-        # Demonstrate how to register an optimizer hook for all model specs
-        hook_called = False
-
-        def my_hook(
-            optimizer: torch.optim.Optimizer,
-            args,
-            kwargs,
-            model_parts: list[nn.Module],
-        ) -> None:
-            nonlocal hook_called
-            hook_called = True
-
-        def register_optimizer_hook_to_spec(spec: TrainSpec) -> TrainSpec:
-            # Create a closure to capture the original spec.build_optimizers_fn
-            original_build_optimizers_fn = spec.build_optimizers_fn
-
-            def my_build_optimizer_fn(
-                model_parts: list[nn.Module], job_config: JobConfig
-            ) -> OptimizersContainer:
-                optimizers = original_build_optimizers_fn(model_parts, job_config)
-                optimizers.register_step_post_hook(
-                    partial(my_hook, model_parts=model_parts)
-                )
-                return optimizers
-
-            spec.build_optimizers_fn = my_build_optimizer_fn
-
-        apply_to_train_specs(register_optimizer_hook_to_spec)
-
-        model = new_spec.cls(BaseModelArgs())
-        model_parts = [model]
-        optimizers = 
new_spec.build_optimizers_fn(model_parts, JobConfig()) - assert optimizers.optimizers[0].__class__.__name__ == "Adam" - batch = torch.randn(8, 8) - model(batch).sum().backward() - assert not hook_called - optimizers.step() - assert hook_called
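The hook machinery that `test_optim_hook` routes through `TrainSpec` is, underneath, PyTorch's ordinary optimizer step post-hook API: `register_step_post_hook` fires after every `optimizer.step()` call. A self-contained sketch of just that mechanism on a bare Adam optimizer; the model shape and hook bookkeeping here are arbitrary illustration:

```python
# Plain-PyTorch sketch of the post-step hook pattern used in test_optim_hook.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

hook_calls: list[int] = []
# register_step_post_hook runs the hook after every optimizer.step() call
optimizer.register_step_post_hook(lambda opt, args, kwargs: hook_calls.append(1))

model(torch.randn(4, 8)).sum().backward()
assert not hook_calls  # backward alone does not trigger the hook
optimizer.step()
assert len(hook_calls) == 1
```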