5 changes: 1 addition & 4 deletions apps/grpo/main.py
@@ -16,10 +16,6 @@
 import torch.nn.functional as F
 import torchstore as ts
 from datasets import load_dataset
-from forge.actors._torchstore_utils import (
-    get_dcp_whole_state_dict_key,
-    get_param_prefix,
-)
 from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
@@ -34,6 +30,7 @@
 from forge.observability.perf_tracker import Tracer
 
 from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util._torchstore import get_dcp_whole_state_dict_key, get_param_prefix
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
 from omegaconf import DictConfig
16 changes: 8 additions & 8 deletions src/forge/actors/generator.py
@@ -40,14 +40,6 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
 
-from forge.actors._torchstore_utils import (
-    extract_param_name,
-    get_dcp_whole_state_dict_key,
-    get_param_key,
-    get_param_prefix,
-    load_tensor_from_dcp,
-)
-
 from forge.controller import (
     ForgeActor,
     get_proc_mesh,
@@ -61,6 +53,14 @@
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
 
+from forge.util._torchstore import (
+    extract_param_name,
+    get_dcp_whole_state_dict_key,
+    get_param_key,
+    get_param_prefix,
+    load_tensor_from_dcp,
+)
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
12 changes: 6 additions & 6 deletions src/forge/actors/trainer.py
@@ -37,18 +37,18 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.actors._torchstore_utils import (
-    DcpHandle,
-    get_dcp_whole_state_dict_key,
-    get_param_key,
-)
-
 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
 from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
+from forge.util._torchstore import (
+    DcpHandle,
+    get_dcp_whole_state_dict_key,
+    get_param_key,
+)
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
2 changes: 2 additions & 0 deletions src/forge/util/__init__.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from . import _torchstore
 from .distributed import get_world_size_and_rank
 from .logging import get_logger, log_once, log_rank_zero
 from .metric_logging import get_metric_logger
@@ -13,4 +14,5 @@
     "log_once",
     "log_rank_zero",
     "get_metric_logger",
+    "_torchstore",
 ]
src/forge/actors/_torchstore_utils.py → src/forge/util/_torchstore.py (file renamed without changes)
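With the module now re-exported from forge.util, a quick import-surface check could look like the following sketch (not part of this PR; it assumes a local install of forge at this commit, and the helper names are just the ones referenced by the updated call sites above):

# Minimal sketch: verify the moved helpers resolve from the new location.
from forge.util import _torchstore

for name in (
    "DcpHandle",
    "extract_param_name",
    "get_dcp_whole_state_dict_key",
    "get_param_key",
    "get_param_prefix",
    "load_tensor_from_dcp",
):
    assert hasattr(_torchstore, name), f"missing helper after the move: {name}"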
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits.py
@@ -15,7 +15,6 @@
 import torch
 import torch.nn.functional as F
 import torchstore as ts
-from forge.actors._torchstore_utils import get_param_key
 from forge.actors.generator import Generator
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.cli.config import parse
@@ -25,6 +24,7 @@
 from forge.observability.metric_actors import get_or_create_metric_logger
 
 from forge.observability.metrics import record_metric, Reduce
+from forge.util._torchstore import get_param_key
 from forge.util.ops import selective_log_softmax
 from monarch.actor import endpoint
 from omegaconf import DictConfig
2 changes: 1 addition & 1 deletion tests/unit_tests/test_torchstore_utils.py
@@ -14,7 +14,7 @@
 
 import torch
 import torch.distributed.checkpoint as dcp
-from forge.actors._torchstore_utils import DcpHandle
+from forge.util._torchstore import DcpHandle
 
 ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings(
     r"ignore:.*torch.distributed"
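Out-of-tree code that still imports forge.actors._torchstore_utils will raise ImportError once this change lands. A hypothetical compatibility shim (not part of this PR, shown only as a sketch) could alias the old module path during a migration window:

# Hypothetical shim: forward the removed module path to its new home so stale
# imports keep working while callers are updated.
import sys

import forge.util._torchstore as _torchstore

sys.modules.setdefault("forge.actors._torchstore_utils", _torchstore)

Whether such a shim is worth carrying is a judgment call; the call sites updated in this PR make it unnecessary inside the repository itself.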