Skip to content

Commit 19e5100

Browse files
committed
Move torchstore_utils from actors to util (forge.actors._torchstore_utils → forge.util._torchstore)
1 parent 839664c commit 19e5100

File tree

7 files changed

+19
-20
lines changed

7 files changed

+19
-20
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@
1616
import torch.nn.functional as F
1717
import torchstore as ts
1818
from datasets import load_dataset
19-
from forge.actors._torchstore_utils import (
20-
get_dcp_whole_state_dict_key,
21-
get_param_prefix,
22-
)
2319
from forge.actors.generator import Generator
2420
from forge.actors.reference_model import ReferenceModel
2521
from forge.actors.replay_buffer import ReplayBuffer
@@ -34,6 +30,7 @@
3430
from forge.observability.perf_tracker import Tracer
3531

3632
from forge.types import LauncherConfig, ProvisionerConfig
33+
from forge.util._torchstore import get_dcp_whole_state_dict_key, get_param_prefix
3734
from forge.util.ops import compute_logprobs
3835
from monarch.actor import endpoint
3936
from omegaconf import DictConfig

src/forge/actors/generator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,6 @@
4040
from vllm.v1.structured_output import StructuredOutputManager
4141
from vllm.worker.worker_base import WorkerWrapperBase
4242

43-
from forge.actors._torchstore_utils import (
44-
extract_param_name,
45-
get_dcp_whole_state_dict_key,
46-
get_param_key,
47-
get_param_prefix,
48-
load_tensor_from_dcp,
49-
)
50-
5143
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
5244
from forge.data_models.completion import Completion
5345
from forge.data_models.prompt import to_prompt
@@ -56,6 +48,14 @@
5648
from forge.observability.perf_tracker import Tracer
5749
from forge.types import ProcessConfig
5850

51+
from forge.util._torchstore import (
52+
extract_param_name,
53+
get_dcp_whole_state_dict_key,
54+
get_param_key,
55+
get_param_prefix,
56+
load_tensor_from_dcp,
57+
)
58+
5959
logger = logging.getLogger(__name__)
6060
logger.setLevel(logging.INFO)
6161

src/forge/actors/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@
3838
from torchtitan.experiments.forge.engine import ForgeEngine
3939
from torchtitan.experiments.forge.job_config import ForgeJobConfig
4040

41-
from forge.actors._torchstore_utils import (
42-
DcpHandle,
43-
get_dcp_whole_state_dict_key,
44-
get_param_key,
45-
)
46-
4741
from forge.controller import ForgeActor
4842
from forge.data.utils import batch_to_device
4943
from forge.env import TORCHSTORE_USE_RDMA
5044
from forge.observability.metrics import record_metric, Reduce
5145
from forge.observability.perf_tracker import Tracer
5246

47+
from forge.util._torchstore import (
48+
DcpHandle,
49+
get_dcp_whole_state_dict_key,
50+
get_param_key,
51+
)
52+
5353
logger = logging.getLogger(__name__)
5454
logger.setLevel(logging.DEBUG)
5555

src/forge/util/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
from . import _torchstore
67
from .distributed import get_world_size_and_rank
78
from .logging import get_logger, log_once, log_rank_zero
89
from .metric_logging import get_metric_logger
@@ -13,4 +14,5 @@
1314
"log_once",
1415
"log_rank_zero",
1516
"get_metric_logger",
17+
"_torchstore",
1618
]
src/forge/actors/_torchstore_utils.py → src/forge/util/_torchstore.py

File renamed without changes.

tests/sandbox/toy_rl/sumdigits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch
1616
import torch.nn.functional as F
1717
import torchstore as ts
18-
from forge.actors._torchstore_utils import get_param_key
1918
from forge.actors.generator import Generator
2019
from forge.actors.replay_buffer import ReplayBuffer
2120
from forge.cli.config import parse
@@ -25,6 +24,7 @@
2524
from forge.observability.metric_actors import get_or_create_metric_logger
2625

2726
from forge.observability.metrics import record_metric, Reduce
27+
from forge.util._torchstore import get_param_key
2828
from forge.util.ops import selective_log_softmax
2929
from monarch.actor import endpoint
3030
from omegaconf import DictConfig

tests/unit_tests/test_torchstore_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import torch
1616
import torch.distributed.checkpoint as dcp
17-
from forge.actors._torchstore_utils import DcpHandle
17+
from forge.util._torchstore import DcpHandle
1818

1919
ignore_torch_distributed_unitialized_warning = pytest.mark.filterwarnings(
2020
r"ignore:.*torch.distributed"

0 commit comments

Comments
 (0)