diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1dbef0b76..723813e34 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -16,10 +16,6 @@ import torch.nn.functional as F import torchstore as ts from datasets import load_dataset -from forge.actors._torchstore_utils import ( - get_dcp_whole_state_dict_key, - get_param_prefix, -) from forge.actors.generator import Generator from forge.actors.reference_model import ReferenceModel from forge.actors.replay_buffer import ReplayBuffer @@ -34,6 +30,7 @@ from forge.observability.perf_tracker import Tracer from forge.types import LauncherConfig, ProvisionerConfig +from forge.util._torchstore import get_dcp_whole_state_dict_key, get_param_prefix from forge.util.ops import compute_logprobs from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/src/forge/actors/generator.py b/src/forge/actors/generator.py index e04bed5a8..4e3a81a15 100644 --- a/src/forge/actors/generator.py +++ b/src/forge/actors/generator.py @@ -40,14 +40,6 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.worker.worker_base import WorkerWrapperBase -from forge.actors._torchstore_utils import ( - extract_param_name, - get_dcp_whole_state_dict_key, - get_param_key, - get_param_prefix, - load_tensor_from_dcp, -) - from forge.controller import ( ForgeActor, get_proc_mesh, @@ -61,6 +53,14 @@ from forge.observability.perf_tracker import Tracer from forge.types import ProcessConfig +from forge.util._torchstore import ( + extract_param_name, + get_dcp_whole_state_dict_key, + get_param_key, + get_param_prefix, + load_tensor_from_dcp, +) + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 71049bc52..9535946fc 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -37,18 +37,18 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) - from forge.controller import ForgeActor from forge.data.utils import batch_to_device from forge.env import TORCHSTORE_USE_RDMA from forge.observability.metrics import record_metric, Reduce from forge.observability.perf_tracker import Tracer +from forge.util._torchstore import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py index 5fb03b0f9..c4fcd4ca0 100644 --- a/src/forge/util/__init__.py +++ b/src/forge/util/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from . import _torchstore from .distributed import get_world_size_and_rank from .logging import get_logger, log_once, log_rank_zero from .metric_logging import get_metric_logger @@ -13,4 +14,5 @@ "log_once", "log_rank_zero", "get_metric_logger", + "_torchstore", ] diff --git a/src/forge/actors/_torchstore_utils.py b/src/forge/util/_torchstore.py similarity index 100% rename from src/forge/actors/_torchstore_utils.py rename to src/forge/util/_torchstore.py diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 0668f8eca..54781fab5 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -15,7 +15,6 @@ import torch import torch.nn.functional as F import torchstore as ts -from forge.actors._torchstore_utils import get_param_key from forge.actors.generator import Generator from forge.actors.replay_buffer import ReplayBuffer from forge.cli.config import parse @@ -25,6 +24,7 @@ from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce +from forge.util._torchstore import get_param_key from forge.util.ops import selective_log_softmax from monarch.actor import endpoint from omegaconf import DictConfig diff --git a/tests/unit_tests/test_torchstore_utils.py b/tests/unit_tests/test_torchstore_utils.py index 6a2e23fbf..3ddafdf1b 100644 --- a/tests/unit_tests/test_torchstore_utils.py +++ b/tests/unit_tests/test_torchstore_utils.py @@ -14,7 +14,7 @@ import torch import torch.distributed.checkpoint as dcp -from forge.actors._torchstore_utils import DcpHandle +from forge.util._torchstore import DcpHandle ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings( r"ignore:.*torch.distributed"