
Commit 6ebe50d

Enable RDMA by default (meta-pytorch#461)
1 parent 56c3528 commit 6ebe50d

4 files changed: +30 additions, −21 deletions

src/forge/actors/_torchstore_utils.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed.checkpoint as dcp
 from torch.distributed.checkpoint.metadata import Metadata as DcpMeta
+from torchstore.transport.buffers import rdma_available

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -69,3 +70,8 @@ def extract_param_name(key: str) -> str:

 def get_dcp_whole_state_dict_key(policy_version: int) -> str:
     return f"{get_param_prefix(policy_version)}{KEY_DELIM}{DCP_WHOLE_STATE_TAG}"
+
+
+def rdma_enabled() -> bool:
+    """Return True if TorchStore thinks we're using RDMA."""
+    return rdma_available()
```
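
For context: the new `rdma_enabled()` helper is a thin wrapper over TorchStore's transport-level probe, so forge code can ask one question and branch on the answer. A minimal sketch of the pattern this commit applies in `generator.py` and `trainer.py` follows; `choose_weight_sync_path` is a hypothetical illustration, not part of the commit:

```python
# Minimal sketch, assuming forge and torchstore are importable.
# choose_weight_sync_path() is a hypothetical helper for illustration only.
from forge.actors._torchstore_utils import rdma_enabled

def choose_weight_sync_path() -> str:
    # The same decision this commit encodes in Generator and RLTrainer:
    # use RDMA when TorchStore reports it, otherwise fall back to DCP.
    return "rdma" if rdma_enabled() else "dcp"

print(choose_weight_sync_path())
```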

src/forge/actors/generator.py

Lines changed: 21 additions & 16 deletions
```diff
@@ -46,6 +46,7 @@
     get_param_key,
     get_param_prefix,
     load_tensor_from_dcp,
+    rdma_available,
 )

 from forge.controller import (
@@ -56,7 +57,6 @@
 )
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.types import ProcessConfig
@@ -112,7 +112,7 @@ def __post_init__(self):
         self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY

         if self.use_dcp_for_weight_sync is None:
-            self.use_dcp_for_weight_sync = not TORCHSTORE_USE_RDMA.get_value()
+            self.use_dcp_for_weight_sync = not rdma_available()
         logger.debug(f"{self.use_dcp_for_weight_sync=}")

     @endpoint
@@ -492,14 +492,16 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
         await stop_proc_mesh(actor._generator_proc)

     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purpose. Save model parameters before weight update."""
-        await self.worker.save_model_params.call()
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[Generator] save model parameters for testing.")
+        await self.worker._test_save_model_params.call()

     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purpose. Validate saved params using validate_fn."""
-        return await self.worker.validate_model_params.call(validate_fn)
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Generator] start validating model parameters.")
+        return await self.worker._test_validate_model_params.call(validate_fn)


 @dataclass
@@ -512,6 +514,8 @@ class GeneratorWorker(ForgeActor):
     """

     vllm_config: VllmConfig
+    # TODO: Remove below param
+    _test_prev_params = {}

     @endpoint
     async def setup(self):
@@ -601,19 +605,20 @@ async def update_weights(self, version: int) -> None:
         t.stop()

     @endpoint
-    async def save_model_params(self):
-        """Used for debugging purposes. Save model parameters before weight update."""
-        self._debug_saved_params = {}
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[GeneratorWorker] save model parameters for testing.")
         for name, param in self.worker.model_runner.model.named_parameters():
-            self._debug_saved_params[name] = param.detach().cpu()
+            self._test_prev_params[name] = param.detach().cpu()
         logger.info(
             "[GeneratorWorker] finished saving model parameters, len = %d",
-            len(self._debug_saved_params),
+            len(self._test_prev_params),
         )

     @endpoint
-    async def validate_model_params(self, validate_fn):
-        """Used for debugging purposes. Validate saved params using validate_fn."""
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[GeneratorWorker] start validating model parameters.")
         return validate_fn(
-            self._debug_saved_params, self.worker.model_runner.model, logger
+            self._test_prev_params, self.worker.model_runner.model, logger
         )
```
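
The renamed `_test_*` endpoints expect a `validate_fn` with the signature `(prev_params, model, logger)`, matching the call site above. A hedged sketch of one such function follows; the change-detection logic and the boolean return convention are illustrative assumptions, not defined by this commit:

```python
# Sketch of a validate_fn compatible with _test_validate_model_params.
# Assumes prev_params maps parameter names to CPU tensors, as saved by
# _test_save_model_params; the "did anything change" check is illustrative.
import torch

def validate_fn(prev_params, model, logger):
    changed = 0
    for name, param in model.named_parameters():
        before = prev_params.get(name)
        if before is not None and not torch.equal(before, param.detach().cpu()):
            changed += 1
    logger.info("validate_fn: %d parameters changed after weight update", changed)
    return changed > 0  # True once update_weights has actually swapped weights
```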

src/forge/actors/trainer.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -41,11 +41,11 @@
     DcpHandle,
     get_dcp_whole_state_dict_key,
     get_param_key,
+    rdma_available,
 )

 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
-from forge.env import TORCHSTORE_USE_RDMA
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer

@@ -131,9 +131,7 @@ class RLTrainer(ForgeActor):
     # Non JobConfig-related fields
     loss: Callable = lambda logits, **targets: logits
     state_dict_key: str = "model_state_dict"
-    use_dcp: bool = (
-        TORCHSTORE_USE_RDMA.get_value() == 0
-    )  # torchstore currently only accepts 0 or 1
+    use_dcp: bool = not rdma_available()
     dcp_path: str = "forge_dcp_tmp"

     def __post_init__(self):
```
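
One subtlety worth noting: as a dataclass field default, `use_dcp: bool = not rdma_available()` is evaluated once, when `RLTrainer` is defined at import time, not once per instance. A toy reproduction of that semantics (names here are illustrative, not from the commit):

```python
# Toy reproduction: a dataclass default expression runs at class-definition
# time, so every instance shares the value probed at import.
from dataclasses import dataclass

def probe() -> bool:
    print("probe() evaluated")  # printed once, when Example is defined
    return True

@dataclass
class Example:
    use_dcp: bool = not probe()

a, b = Example(), Example()    # probe() is not called again here
assert a.use_dcp == b.use_dcp  # both reuse the single probed default
```

Passing `use_dcp=...` explicitly at construction still overrides the probed default per instance.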

src/forge/env.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -101,7 +101,7 @@ def get_value(self) -> Any:

 TORCHSTORE_USE_RDMA = EnvVar(
     name="TORCHSTORE_RDMA_ENABLED",
-    default=0,
+    default=1,
     description="Whether or not to use RDMA in TorchStore.",
 )

```
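
With the default flipped to 1, the variable reports RDMA as enabled unless it is set explicitly. A hedged sketch of overriding it at launch; this assumes `EnvVar.get_value()` reads `os.environ` and coerces to int, which the old `get_value() == 0` comparison in `trainer.py` suggests but this diff does not show:

```python
# Override sketch: set TORCHSTORE_RDMA_ENABLED=0 before forge is imported
# so every get_value() call sees the override instead of the new default.
import os

os.environ["TORCHSTORE_RDMA_ENABLED"] = "0"  # must precede the forge import

from forge.env import TORCHSTORE_USE_RDMA

print(TORCHSTORE_USE_RDMA.get_value())  # expected: 0, not the default 1
```

Note that after this commit the generator and trainer defaults consult the runtime probe `rdma_available()` rather than this variable, so the override affects only code paths that still read `TORCHSTORE_USE_RDMA`.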
