diff --git a/.meta/mast/main.py b/.meta/mast/main.py index f5b81e25d..513d96fc6 100644 --- a/.meta/mast/main.py +++ b/.meta/mast/main.py @@ -9,7 +9,6 @@ import sys from apps.grpo.main import main as grpo_main -from forge.cli.config import parse from forge.controller.launcher import ( JOB_NAME_KEY, LAUNCHER_KEY, @@ -25,6 +24,7 @@ ProvisionerConfig, ServiceConfig, ) +from forge.util.config import parse from omegaconf import DictConfig DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder" diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1dbef0b76..0d49ef9ca 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -24,7 +24,6 @@ from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer from forge.actors.trainer import RLTrainer -from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.provisioner import init_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward @@ -34,6 +33,7 @@ from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/apps/sft/main.py b/apps/sft/main.py index 27a8036d4..aa484608e 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -22,12 +22,12 @@ import torch import torchtitan.experiments.forge.train_spec as forge_train_spec -from forge.cli.config import parse from forge.controller import ForgeActor from forge.data.collate import collate_packed from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.util.config import parse from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf diff --git a/src/forge/cli/__init__.py b/src/forge/cli/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/src/forge/cli/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/src/forge/cli/config.py b/src/forge/util/config.py similarity index 100% rename from src/forge/cli/config.py rename to src/forge/util/config.py diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py index 0b99e75a2..01f01a390 100644 --- a/tests/integration_tests/test_policy_update.py +++ b/tests/integration_tests/test_policy_update.py @@ -17,11 +17,11 @@ from forge.actors.generator import Generator from forge.actors.trainer import RLTrainer -from forge.cli.config import resolve_hf_hub_paths from forge.controller.provisioner import init_provisioner from forge.controller.service.service import uuid from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import resolve_hf_hub_paths from monarch.actor import endpoint from omegaconf import DictConfig, OmegaConf diff --git a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py index 4fcd850e7..83e8809a7 100644 --- a/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py +++ b/tests/integration_tests/test_titan_fwd_vs_hf_fwd.py @@ -25,9 +25,9 @@ import torch from forge.actors.reference_model import ReferenceModel -from forge.cli.config import _resolve_hf_model_path from forge.controller import ForgeActor from forge.controller.provisioner import shutdown +from forge.util.config import _resolve_hf_model_path from monarch.actor import endpoint from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index e5ee6fddd..55714c49d 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -11,7 +11,6 @@ import torch import torchstore as ts from forge.actors.trainer import RLTrainer -from forge.cli.config import parse from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY from forge.controller.provisioner import init_provisioner, shutdown from forge.observability.metric_actors import get_or_create_metric_logger @@ -23,6 +22,7 @@ ProvisionerConfig, ServiceConfig, ) +from forge.util.config import parse from omegaconf import DictConfig from vllm.transformers_utils.tokenizer import get_tokenizer diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 0668f8eca..01a0f3936 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -18,13 +18,13 @@ from forge.actors._torchstore_utils import get_param_key from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer -from forge.cli.config import parse from forge.controller.actor import ForgeActor from forge.controller.provisioner import shutdown from forge.losses.grpo_loss import SimpleGRPOLoss from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce +from forge.util.config import parse from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 0d4652a6b..e9b001aa5 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -14,13 +14,13 @@ import os from forge.actors.generator import Generator -from forge.cli.config import parse from forge.controller.provisioner import init_provisioner, shutdown from forge.data_models.completion import Completion from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import LauncherConfig, ProvisionerConfig +from forge.util.config import parse from omegaconf import DictConfig os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600" diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 69cc7e2ed..64a00c759 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -8,7 +8,7 @@ import pytest -from forge.cli.config import resolve_hf_hub_paths +from forge.util.config import resolve_hf_hub_paths from omegaconf import DictConfig, OmegaConf @@ -39,7 +39,7 @@ ({"level1": {"level2": {"model": "hf://deep/model"}}}, [("deep/model",)]), ], ) -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_hf_path_resolution(mock_download, config_data, expected_calls): """Test hf:// path resolution in various data structures.""" mock_download.return_value = "/fake/cache/model" @@ -78,7 +78,7 @@ def test_non_hf_paths_unchanged(config_data): # Cache behavior tests -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_cache_hit_scenario(mock_download): """Test behavior when model is already cached.""" mock_download.return_value = "/fake/cache/model" @@ -93,7 +93,7 @@ def test_cache_hit_scenario(mock_download): assert result.model == "/fake/cache/model" -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_cache_miss_scenario(mock_download): """Test behavior when model is not cached.""" from huggingface_hub.utils import LocalEntryNotFoundError @@ -145,7 +145,7 @@ def test_invalid_hf_urls(invalid_hf_url, expected_error): assert expected_error in str(exc_info.value) -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_download_failure_handling(mock_download): """Test error handling when download fails.""" mock_download.side_effect = Exception("Network error: Repository not found") @@ -159,7 +159,7 @@ def test_download_failure_handling(mock_download): # Integration test with mixed data types -@patch("forge.cli.config.snapshot_download") +@patch("forge.util.config.snapshot_download") def test_complex_real_world_config(mock_download): """Test with a realistic complex configuration.""" mock_download.return_value = "/fake/cache/model"