File tree Expand file tree Collapse file tree 7 files changed +19
-20
lines changed Expand file tree Collapse file tree 7 files changed +19
-20
lines changed Original file line number Diff line number Diff line change 1616import torch .nn .functional as F
1717import torchstore as ts
1818from datasets import load_dataset
19- from forge .actors ._torchstore_utils import (
20- get_dcp_whole_state_dict_key ,
21- get_param_prefix ,
22- )
2319from forge .actors .generator import Generator
2420from forge .actors .reference_model import ReferenceModel
2521from forge .actors .replay_buffer import ReplayBuffer
3430from forge .observability .perf_tracker import Tracer
3531
3632from forge .types import LauncherConfig , ProvisionerConfig
33+ from forge .util ._torchstore import get_dcp_whole_state_dict_key , get_param_prefix
3734from forge .util .ops import compute_logprobs
3835from monarch .actor import endpoint
3936from omegaconf import DictConfig
Original file line number Diff line number Diff line change 4040from vllm .v1 .structured_output import StructuredOutputManager
4141from vllm .worker .worker_base import WorkerWrapperBase
4242
43- from forge .actors ._torchstore_utils import (
44- extract_param_name ,
45- get_dcp_whole_state_dict_key ,
46- get_param_key ,
47- get_param_prefix ,
48- load_tensor_from_dcp ,
49- )
50-
5143from forge .controller import ForgeActor , get_proc_mesh , stop_proc_mesh
5244from forge .data_models .completion import Completion
5345from forge .data_models .prompt import to_prompt
5648from forge .observability .perf_tracker import Tracer
5749from forge .types import ProcessConfig
5850
51+ from forge .util ._torchstore import (
52+ extract_param_name ,
53+ get_dcp_whole_state_dict_key ,
54+ get_param_key ,
55+ get_param_prefix ,
56+ load_tensor_from_dcp ,
57+ )
58+
5959logger = logging .getLogger (__name__ )
6060logger .setLevel (logging .INFO )
6161
Original file line number Diff line number Diff line change 3838from torchtitan .experiments .forge .engine import ForgeEngine
3939from torchtitan .experiments .forge .job_config import ForgeJobConfig
4040
41- from forge .actors ._torchstore_utils import (
42- DcpHandle ,
43- get_dcp_whole_state_dict_key ,
44- get_param_key ,
45- )
46-
4741from forge .controller import ForgeActor
4842from forge .data .utils import batch_to_device
4943from forge .env import TORCHSTORE_USE_RDMA
5044from forge .observability .metrics import record_metric , Reduce
5145from forge .observability .perf_tracker import Tracer
5246
47+ from forge .util ._torchstore import (
48+ DcpHandle ,
49+ get_dcp_whole_state_dict_key ,
50+ get_param_key ,
51+ )
52+
5353logger = logging .getLogger (__name__ )
5454logger .setLevel (logging .DEBUG )
5555
Original file line number Diff line number Diff line change 33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6+ from . import _torchstore
67from .distributed import get_world_size_and_rank
78from .logging import get_logger , log_once , log_rank_zero
89from .metric_logging import get_metric_logger
1314 "log_once" ,
1415 "log_rank_zero" ,
1516 "get_metric_logger" ,
17+ "_torchstore" ,
1618]
File renamed without changes.
Original file line number Diff line number Diff line change 1515import torch
1616import torch .nn .functional as F
1717import torchstore as ts
18- from forge .actors ._torchstore_utils import get_param_key
1918from forge .actors .generator import Generator
2019from forge .actors .replay_buffer import ReplayBuffer
2120from forge .cli .config import parse
2524from forge .observability .metric_actors import get_or_create_metric_logger
2625
2726from forge .observability .metrics import record_metric , Reduce
27+ from forge .util ._torchstore import get_param_key
2828from forge .util .ops import selective_log_softmax
2929from monarch .actor import endpoint
3030from omegaconf import DictConfig
Original file line number Diff line number Diff line change 1414
1515import torch
1616import torch .distributed .checkpoint as dcp
17- from forge .actors . _torchstore_utils import DcpHandle
17+ from forge .util . _torchstore import DcpHandle
1818
1919ignore_torch_distributed_unitialized_warning = pytest .mark .filterwarnings (
2020 r"ignore:.*torch.distributed"
You can’t perform that action at this time.
0 commit comments