
Conversation

@pradeepfn pradeepfn commented Sep 16, 2025

In this PR:

  1. We expand the single-worker test case that shares weights between the RLTrainer and the Policy engine to multiple workers.
  2. We also introduce tensor parallelism in the trainer as a first step toward getting e2e signals for different parallelism strategies.
  3. Currently trainer_num_workers = 2, policy_num_worker = 2, and TP = 2.
  4. Since torchstore maintains only a local storage volume per host, the test in its current form does not represent multi-node weight sync.
  5. Supporting/testing multiple torchstore storage volumes is straightforward, though: we just have to switch the default-strategy to the local-strategy.
  6. This will be done in an immediate follow-up diff.

Summary:

  1. Intra-node weight sync between N -> N workers/ranks, with a single storage volume - done (this PR).
  2. Intra-node weight sync between N -> N workers/ranks, with multiple storage volumes - next PR/WIP.
  3. Inter-node weight sync between N -> N workers/ranks, with multiple storage volumes - not started yet; should work OOTB.

Test:

pytest tests/integration_tests/test_policy_update.py::TestWeightSync::test_policy_update_tp
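The N -> N flow exercised by this test can be sketched with a plain dict standing in for the single torchstore volume. Note the names below (trainer_push, policy_pull, the key layout) are illustrative assumptions for this sketch, not the actual torchstore or forge API:

```python
# Minimal sketch of single-volume weight sync: trainer ranks publish full
# tensors under a versioned key; each policy worker reads back only its
# TP shard along dim 0. Tensors are plain nested lists for illustration.
store = {}  # stand-in for the single torchstore volume on this host

def trainer_push(state_dict, version):
    # Trainer publishes each tensor under (version, name).
    for name, tensor in state_dict.items():
        store[(version, name)] = tensor

def policy_pull(name, version, tp_rank, tp_size):
    # Each policy worker slices out its contiguous shard of rows.
    full = store[(version, name)]
    shard = len(full) // tp_size
    return full[tp_rank * shard:(tp_rank + 1) * shard]

trainer_push({"layer.weight": [[1, 2], [3, 4], [5, 6], [7, 8]]}, version=1)
rank0 = policy_pull("layer.weight", 1, tp_rank=0, tp_size=2)
rank1 = policy_pull("layer.weight", 1, tp_rank=1, tp_size=2)
```

With multiple storage volumes (the follow-up PR), each host would hold its own `store`, but the versioned put/get pattern stays the same.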

@pradeepfn pradeepfn marked this pull request as draft September 16, 2025 16:10
@meta-cla meta-cla bot added the CLA Signed label Sep 16, 2025
@pradeepfn (Author) commented:

I'm encountering the following error, even when running the already-established single-worker unit test.

2025-09-16T15:51:09.716656Z ERROR error invoking get_reconstruction api, error: ReqwestMiddlewareError(Middleware(Request failed after 5 retries

Caused by:
0: error sending request for url (https://cas-server.xethub.hf.co/reconstruction/17a41bb52c7f7ac39f0060f9d12213e0448696767d13dfddc10fc00f79a84a9d)
1: client error (Connect)
2: tunnel error: failed to create underlying connection
3: dns error
4: failed to lookup address information: Name or service not known)), caller: "/home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:322"
at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28

2025-09-16T15:51:18.623345Z WARN Reqwest(reqwest::Error { kind: Request, url: "https://cas-server.xethub.hf.co/reconstruction/e2a9405ede2b482e4d566aadf8bf470bd4e42542ef327709edadf724e1950d47", source: hyper_util::client::legacy::Error(Connect, ConnectFailed(ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: Name or service not known" }))) }). Retrying...
at /home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:226

2025-09-16T15:51:18.623379Z ERROR error invoking get_reconstruction api, error: ReqwestMiddlewareError(Middleware(Request failed after 5 retries

Caused by:
0: error sending request for url (https://cas-server.xethub.hf.co/reconstruction/e2a9405ede2b482e4d566aadf8bf470bd4e42542ef327709edadf724e1950d47)
1: client error (Connect)
2: tunnel error: failed to create underlying connection
3: dns error
4: failed to lookup address information: Name or service not known)), caller: "/home/runner/work/xet-core/xet-core/cas_client/src/http_client.rs:322"
at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28

------------------------------------------------------------------------------------- Captured stderr setup -------------------------------------------------------------------------------------
Fetching 2 files: 0%| | 0/2 [00:23<?, ?it/s]
==================================================================================== short test summary info ====================================================================================
ERROR tests/integration_tests/test_policy_update.py::TestWeightSync::test_policy_update_single - RuntimeError: Data processing error: CAS service error : ReqwestMiddleware Error: Request failed after 5 retries
======================================================================================= 1 error in 39.88s ===

@Jack-Khuu (Contributor) commented:

Can you try HF_HUB_DISABLE_XET=1 to get around the CAS error?
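For context, HF_HUB_DISABLE_XET is a real huggingface_hub environment variable that disables the Xet/CAS download backend; setting it from inside the test process (rather than the shell) is the assumption in this sketch, and it should happen before huggingface_hub is first imported since the library typically reads its config at import time:

```python
import os

# Disable the Xet/CAS backend so huggingface_hub falls back to plain
# HTTP downloads. Set this before huggingface_hub is imported.
os.environ["HF_HUB_DISABLE_XET"] = "1"
```

Equivalently, `HF_HUB_DISABLE_XET=1 pytest ...` on the command line avoids the import-order question entirely.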

@pradeepfn (Author) commented:

Getting the following error.

Aggregated Logs (2025-09-16 19:55:17) >>>
[2 similar log lines] Traceback (most recent call last):
[16 similar log lines] File "/home/pradeepfdo/.conda/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 827, in instrumented
[4 similar log lines] result = await the_method(*args, **kwargs)
[4 similar log lines] File "/home/pradeepfdo/forge_fork/src/forge/actors/policy.py", line 374, in update_weights
[2 similar log lines] await self.policy_worker.update.call(version=self.weights_version)
[4 similar log lines] File "/home/pradeepfdo/.conda/envs/forge/lib/python3.10/site-packages/monarch/_src/actor/future.py", line 136, in mark_complete
[2 similar log lines] func, value = fut.set_result, await coro
[2 similar log lines] rank, value = await r._recv()
[2 similar log lines] return self._process(result)
[2 similar log lines] return rank, super()._process(msg)
[2 similar log lines] raise cast(Exception, payload)
[2 similar log lines] monarch._src.actor.actor_mesh.ActorError: A remote actor call has failed.
[2 similar log lines] Traceback of where the remote call failed (most recent call last):
[3 similar log lines] result = await instrumented()
[3 similar log lines] raise e
[2 similar log lines] await self._load_tensor_parallel_state_dict(current_state_dict, version)
[2 similar log lines] File "/home/pradeepfdo/forge_fork/src/forge/actors/policy.py", line 429, in _load_tensor_parallel_state_dict
[2 similar log lines] sharding.load_from_source_to_target(
[2 similar log lines] File "/home/pradeepfdo/forge_fork/src/forge/data/sharding.py", line 43, in load_from_source_to_target
[2 similar log lines] sharded_tensor = self._calculate_tensor_shard(
[2 similar log lines] File "/home/pradeepfdo/forge_fork/src/forge/data/sharding.py", line 122, in _calculate_tensor_shard
[2 similar log lines] tensor_size = full_tensor.shape[shard_dim]
[2 similar log lines] AttributeError: 'dict' object has no attribute 'shape'
[2 similar log lines]
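The AttributeError at the bottom of the traceback says `_calculate_tensor_shard` received a dict (e.g. a whole state dict) where a single tensor was expected, so `.shape` failed. A hypothetical reconstruction of the shard math with a type guard (this is a sketch, not the actual code in forge/data/sharding.py):

```python
import torch

def calculate_tensor_shard(full_tensor, shard_dim, tp_rank, tp_size):
    # Hypothetical reconstruction of the TP shard math; names and
    # signature are assumptions based on the traceback above.
    if not torch.is_tensor(full_tensor):
        # The bug above: a dict was passed where a tensor was expected,
        # so full_tensor.shape raised AttributeError. Fail loudly instead.
        raise TypeError(f"expected a tensor, got {type(full_tensor).__name__}")
    tensor_size = full_tensor.shape[shard_dim]
    assert tensor_size % tp_size == 0, "tensor must divide evenly across TP ranks"
    shard_size = tensor_size // tp_size
    # Contiguous slice of shard_size elements along shard_dim for this rank.
    return full_tensor.narrow(shard_dim, tp_rank * shard_size, shard_size)
```

The fix referenced below (meta-pytorch/torchstore#32) addresses what actually gets handed to the shard calculation.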

@pradeepfn (Author) commented:

@joecummings confirmed that there is a known issue in the DP sharding. Thanks, Joe! Going to try out TP.

@pradeepfn (Author) commented:

Apparently, meta-pytorch/torchstore#32 solves the issue. Let me test.

@pradeepfn (Author) commented:

pull/32 fixed the issue.

@pradeepfn pradeepfn marked this pull request as ready for review September 18, 2025 03:50
@pradeepfn (Author) commented:

[Screenshot, 2025-09-18 at 8:56:05 AM]

@joecummings (Member) left a comment:


FIX THE LINT :)

@pradeepfn pradeepfn merged commit 8d763cf into meta-pytorch:main Sep 19, 2025
5 checks passed
