Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .meta/mast/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sys

from apps.grpo.main import main as grpo_main
from forge.cli.config import parse
from forge.controller.launcher import (
JOB_NAME_KEY,
LAUNCHER_KEY,
Expand All @@ -25,6 +24,7 @@
ProvisionerConfig,
ServiceConfig,
)
from forge.util.config import parse
from omegaconf import DictConfig

DEFAULT_CHECKPOINT_FOLDER_KEY = "checkpoint_folder"
Expand Down
2 changes: 1 addition & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
Expand All @@ -34,6 +33,7 @@
from forge.observability.perf_tracker import Tracer

from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down
2 changes: 1 addition & 1 deletion apps/sft/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import torch

import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.cli.config import parse
from forge.controller import ForgeActor
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.util.config import parse

from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
Expand Down
5 changes: 0 additions & 5 deletions src/forge/cli/__init__.py

This file was deleted.

File renamed without changes.
2 changes: 1 addition & 1 deletion tests/integration_tests/test_policy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from forge.actors.generator import Generator

from forge.actors.trainer import RLTrainer
from forge.cli.config import resolve_hf_hub_paths
from forge.controller.provisioner import init_provisioner

from forge.controller.service.service import uuid
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import resolve_hf_hub_paths
from monarch.actor import endpoint

from omegaconf import DictConfig, OmegaConf
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/test_titan_fwd_vs_hf_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import torch

from forge.actors.reference_model import ReferenceModel
from forge.cli.config import _resolve_hf_model_path
from forge.controller import ForgeActor
from forge.controller.provisioner import shutdown
from forge.util.config import _resolve_hf_model_path
from monarch.actor import endpoint
from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand Down
2 changes: 1 addition & 1 deletion tests/sandbox/rl_trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torch
import torchstore as ts
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.launcher import JOB_NAME_KEY, LAUNCHER_KEY
from forge.controller.provisioner import init_provisioner, shutdown
from forge.observability.metric_actors import get_or_create_metric_logger
Expand All @@ -23,6 +22,7 @@
ProvisionerConfig,
ServiceConfig,
)
from forge.util.config import parse
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand Down
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
from forge.actors._torchstore_utils import get_param_key
from forge.actors.generator import Generator
from forge.actors.replay_buffer import ReplayBuffer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.observability.metric_actors import get_or_create_metric_logger

from forge.observability.metrics import record_metric, Reduce
from forge.util.config import parse
from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down
2 changes: 1 addition & 1 deletion tests/sandbox/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import os

from forge.actors.generator import Generator
from forge.cli.config import parse

from forge.controller.provisioner import init_provisioner, shutdown

from forge.data_models.completion import Completion
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.config import parse
from omegaconf import DictConfig

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
Expand Down
12 changes: 6 additions & 6 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from forge.cli.config import resolve_hf_hub_paths
from forge.util.config import resolve_hf_hub_paths
from omegaconf import DictConfig, OmegaConf


Expand Down Expand Up @@ -39,7 +39,7 @@
({"level1": {"level2": {"model": "hf://deep/model"}}}, [("deep/model",)]),
],
)
@patch("forge.cli.config.snapshot_download")
@patch("forge.util.config.snapshot_download")
def test_hf_path_resolution(mock_download, config_data, expected_calls):
"""Test hf:// path resolution in various data structures."""
mock_download.return_value = "/fake/cache/model"
Expand Down Expand Up @@ -78,7 +78,7 @@ def test_non_hf_paths_unchanged(config_data):


# Cache behavior tests
@patch("forge.cli.config.snapshot_download")
@patch("forge.util.config.snapshot_download")
def test_cache_hit_scenario(mock_download):
"""Test behavior when model is already cached."""
mock_download.return_value = "/fake/cache/model"
Expand All @@ -93,7 +93,7 @@ def test_cache_hit_scenario(mock_download):
assert result.model == "/fake/cache/model"


@patch("forge.cli.config.snapshot_download")
@patch("forge.util.config.snapshot_download")
def test_cache_miss_scenario(mock_download):
"""Test behavior when model is not cached."""
from huggingface_hub.utils import LocalEntryNotFoundError
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_invalid_hf_urls(invalid_hf_url, expected_error):
assert expected_error in str(exc_info.value)


@patch("forge.cli.config.snapshot_download")
@patch("forge.util.config.snapshot_download")
def test_download_failure_handling(mock_download):
"""Test error handling when download fails."""
mock_download.side_effect = Exception("Network error: Repository not found")
Expand All @@ -159,7 +159,7 @@ def test_download_failure_handling(mock_download):


# Integration test with mixed data types
@patch("forge.cli.config.snapshot_download")
@patch("forge.util.config.snapshot_download")
def test_complex_real_world_config(mock_download):
"""Test with a realistic complex configuration."""
mock_download.return_value = "/fake/cache/model"
Expand Down
Loading