Conversation

ankitageorge
Contributor
@ankitageorge ankitageorge commented Aug 18, 2025

Add logic that loads the model from torchstore into vLLM, handling both the single-rank and distributed cases.
Adds a test that writes the model to torchstore and then reads it back via the changes to the update method in the policy actor.
Output:

forge-8e46c2a) [[email protected] ~/forge (torchstore-testing)]$ pytest -v -s tests/integration_tests/test_policy_update.py::test_llama3_policy_update_tp
================================================================= test session starts =================================================================
platform linux -- Python 3.10.18, pytest-8.4.1, pluggy-1.6.0 -- /home/ankitageorge/.fbpkg_conda_envs/forge-8e46c2a/bin/python3.10
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /home/ankitageorge/forge
configfile: pyproject.toml
plugins: cov-6.2.1, hypothesis-6.136.6, asyncio-1.1.0, anyio-4.10.0, timeout-2.4.0, typeguard-4.2.1
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... INFO 08-20 14:26:01 [__init__.py:235] Automatically detected platform cuda.
collected 1 item                                                                                                                                      

tests/integration_tests/test_policy_update.py::test_llama3_policy_update_tp === PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.07s/it]
Original state dict has 291 parameters
Converting transformers state dict to vLLM format...
Converted state dict has 195 parameters
Saving 195 tensors
Successfully saved 195 tensors
Successfully wrote converted state dict to torchstore with key: llama3_8b_state_dict
Starting tensor parallel test (load full state dict into sharded model)...

=== PHASE 2: Testing Policy Integration (GPUs: 2) ===
Using MASTER_PORT: 55665 for tensor parallel Policy
INFO 08-20 14:27:10 [__init__.py:235] Automatically detected platform cuda.
INFO 08-20 14:27:11 [__init__.py:235] Automatically detected platform cuda.
INFO 08-20 14:27:19 [config.py:1604] Using max model len 131072
INFO 08-20 14:27:19 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
WARNING 08-20 14:27:19 [__init__.py:2899] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initialized
WARNING 08-20 14:27:19 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 112 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 08-20 14:27:19 [config.py:1604] Using max model len 131072
INFO 08-20 14:27:19 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=16384.
WARNING 08-20 14:27:19 [__init__.py:2899] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initialized
WARNING 08-20 14:27:19 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 112 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
WARNING: Logging before InitGoogleLogging() is written to STDERR
W0820 14:27:21.482237 4030958 ProcessGroupNCCL.cpp:915] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
I0820 14:27:21.482263 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 0 PG GUID 0 Rank 1] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.482270 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 0 PG GUID 0 Rank 1] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.482344 4030958 ProcessGroupNCCL.cpp:971] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL initialization options: size: 2, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 0
I0820 14:27:21.482348 4030958 ProcessGroupNCCL.cpp:980] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
I0820 14:27:21.482630 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 1 PG GUID 1 Rank 1] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.482635 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 1 PG GUID 1 Rank 1] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.482715 4030958 ProcessGroupNCCL.cpp:971] [PG ID 1 PG GUID 1 Rank 1] ProcessGroupNCCL initialization options: size: 2, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 1
I0820 14:27:21.482718 4030958 ProcessGroupNCCL.cpp:980] [PG ID 1 PG GUID 1 Rank 1] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
WARNING: Logging before InitGoogleLogging() is written to STDERR
W0820 14:27:21.486436 4030959 ProcessGroupNCCL.cpp:915] Warning: TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated. (function operator())
I0820 14:27:21.486460 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 0 PG GUID 0 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.486465 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 0 PG GUID 0 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.486558 4030959 ProcessGroupNCCL.cpp:971] [PG ID 0 PG GUID 0 Rank 0] ProcessGroupNCCL initialization options: size: 2, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 0
I0820 14:27:21.486563 4030959 ProcessGroupNCCL.cpp:980] [PG ID 0 PG GUID 0 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
I0820 14:27:21.486845 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 1 PG GUID 1 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.486850 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 1 PG GUID 1 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.486914 4030959 ProcessGroupNCCL.cpp:971] [PG ID 1 PG GUID 1 Rank 0] ProcessGroupNCCL initialization options: size: 2, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 1
I0820 14:27:21.486918 4030959 ProcessGroupNCCL.cpp:980] [PG ID 1 PG GUID 1 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
I0820 14:27:21.501228 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 2 PG GUID 3 Rank 1] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.501257 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 2 PG GUID 3 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:21.501240 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 2 PG GUID 3 Rank 1] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.501312 4030958 ProcessGroupNCCL.cpp:971] [PG ID 2 PG GUID 3 Rank 1] ProcessGroupNCCL initialization options: size: 2, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 3
I0820 14:27:21.501268 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 2 PG GUID 3 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:21.501315 4030958 ProcessGroupNCCL.cpp:980] [PG ID 2 PG GUID 3 Rank 1] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
I0820 14:27:21.501345 4030959 ProcessGroupNCCL.cpp:971] [PG ID 2 PG GUID 3 Rank 0] ProcessGroupNCCL initialization options: size: 2, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 3
I0820 14:27:21.501350 4030959 ProcessGroupNCCL.cpp:980] [PG ID 2 PG GUID 3 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
INFO 08-20 14:27:21 [__init__.py:1375] Found nccl from library libnccl.so.2
INFO 08-20 14:27:21 [__init__.py:1375] Found nccl from library libnccl.so.2
INFO 08-20 14:27:21 [pynccl.py:70] vLLM is using nccl==2.25.1
INFO 08-20 14:27:21 [pynccl.py:70] vLLM is using nccl==2.25.1
NCCL version 2.25.1+cuda12.4
INFO 08-20 14:27:22 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /home/ankitageorge/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 08-20 14:27:22 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /home/ankitageorge/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
INFO 08-20 14:27:22 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_2c9a81ef'), local_subscribe_addr='ipc:///tmp/938bd84a-f8fc-460b-8105-12948b5c1a0f', remote_subscribe_addr=None, remote_addr_ipv6=False)
I0820 14:27:22.396420 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 3 PG GUID 5 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.396435 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 3 PG GUID 5 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:22.396472 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 3 PG GUID 7 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.396485 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 3 PG GUID 7 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:22.396550 4030958 ProcessGroupNCCL.cpp:971] [PG ID 3 PG GUID 7 Rank 0] ProcessGroupNCCL initialization options: size: 1, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 7
I0820 14:27:22.396513 4030959 ProcessGroupNCCL.cpp:971] [PG ID 3 PG GUID 5 Rank 0] ProcessGroupNCCL initialization options: size: 1, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 5
I0820 14:27:22.396553 4030958 ProcessGroupNCCL.cpp:980] [PG ID 3 PG GUID 7 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
I0820 14:27:22.396517 4030959 ProcessGroupNCCL.cpp:980] [PG ID 3 PG GUID 5 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
I0820 14:27:22.398516 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 4 PG GUID 11 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.398527 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 4 PG GUID 11 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
I0820 14:27:22.398571 4030958 ProcessGroupNCCL.cpp:971] [PG ID 4 PG GUID 11 Rank 0] ProcessGroupNCCL initialization options: size: 1, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 11
I0820 14:27:22.398574 4030958 ProcessGroupNCCL.cpp:980] [PG ID 4 PG GUID 11 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
I0820 14:27:22.398936 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 4 PG GUID 9 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.398947 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 4 PG GUID 9 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:22.398993 4030959 ProcessGroupNCCL.cpp:971] [PG ID 4 PG GUID 9 Rank 0] ProcessGroupNCCL initialization options: size: 1, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 9
I0820 14:27:22.398995 4030959 ProcessGroupNCCL.cpp:980] [PG ID 4 PG GUID 9 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
I0820 14:27:22.399775 4030958 ProcessGroupNCCL.cpp:1669] [PG ID 5 PG GUID 13 Rank 1] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.399782 4030958 ProcessGroupNCCL.cpp:1998] [PG ID 5 PG GUID 13 Rank 1] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:22.399834 4030958 ProcessGroupNCCL.cpp:971] [PG ID 5 PG GUID 13 Rank 1] ProcessGroupNCCL initialization options: size: 2, global rank: 1, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 13
I0820 14:27:22.399837 4030958 ProcessGroupNCCL.cpp:980] [PG ID 5 PG GUID 13 Rank 1] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
I0820 14:27:22.400424 4030959 ProcessGroupNCCL.cpp:1669] [PG ID 5 PG GUID 13 Rank 0] HeartbeatMonitor environments: TORCH_NCCL_ENABLE_MONITORING (Whether to kill program when no watchdog heartbeat detected): 1, TORCH_NCCL_DUMP_ON_TIMEOUT: 1, TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: 15000, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: 480, TORCH_NCCL_COORD_CHECK_MILSEC: 1000, TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN: 1
I0820 14:27:22.400432 4030959 ProcessGroupNCCL.cpp:1998] [PG ID 5 PG GUID 13 Rank 0] PGNCCL Watchdog environments: TORCH_NCCL_RETHROW_CUDA_ERRORS: 1, TORCH_NCCL_PROPAGATE_ERROR: 0, TORCH_NCCL_DESYNC_DEBUG: 0
I0820 14:27:22.400481 4030959 ProcessGroupNCCL.cpp:971] [PG ID 5 PG GUID 13 Rank 0] ProcessGroupNCCL initialization options: size: 2, global rank: 0, TIMEOUT(ms): 600000, USE_HIGH_PRIORITY_STREAM: 0, SPLIT_FROM: 0, SPLIT_COLOR: -2, PG Name: 13
I0820 14:27:22.400485 4030959 ProcessGroupNCCL.cpp:980] [PG ID 5 PG GUID 13 Rank 0] ProcessGroupNCCL environments: NCCL version: 2.25.1, TORCH_NCCL_ASYNC_ERROR_HANDLING: 3, TORCH_NCCL_ENABLE_TIMING: 0, TORCH_NCCL_BLOCKING_WAIT: 0, TORCH_DISTRIBUTED_DEBUG: OFF, TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: 0, TORCH_NCCL_TRACE_BUFFER_SIZE: 2000, TORCH_NCCL_NAN_CHECK: 0, TORCH_NCCL_CUDA_EVENT_CACHE: 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
INFO 08-20 14:27:22 [parallel_state.py:1102] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
INFO 08-20 14:27:22 [parallel_state.py:1102] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
WARNING 08-20 14:27:22 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
WARNING 08-20 14:27:22 [topk_topp_sampler.py:59] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 08-20 14:27:22 [gpu_model_runner.py:1843] Starting to load model meta-llama/Meta-Llama-3.1-8B-Instruct...
INFO 08-20 14:27:22 [gpu_model_runner.py:1843] Starting to load model meta-llama/Meta-Llama-3.1-8B-Instruct...
INFO 08-20 14:27:22 [gpu_model_runner.py:1875] Loading model from scratch...
INFO 08-20 14:27:22 [gpu_model_runner.py:1875] Loading model from scratch...
INFO 08-20 14:27:23 [cuda.py:290] Using Flash Attention backend on V1 engine.
INFO 08-20 14:27:23 [cuda.py:290] Using Flash Attention backend on V1 engine.
INFO 08-20 14:27:23 [weight_utils.py:296] Using model weights format ['*.safetensors']
INFO 08-20 14:27:23 [weight_utils.py:296] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  7.33it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.63it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.27it/s]
INFO 08-20 14:27:27 [default_loader.py:262] Loading weights took 3.27 seconds
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.19it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.34it/s]

INFO 08-20 14:27:27 [default_loader.py:262] Loading weights took 3.11 seconds
INFO 08-20 14:27:28 [gpu_model_runner.py:1892] Model loading took 7.5123 GiB and 4.536818 seconds
INFO 08-20 14:27:28 [gpu_model_runner.py:1892] Model loading took 7.5123 GiB and 4.864782 seconds
Setup completed successfully!
Calling Policy.update() to load weights from torchstore...
[-]E0820 14:27:41.036485 3985483 fbcode/monarch/hyperactor/src/channel/net.rs:644] session [email protected]: failed to deliver message within timeout
[-]E0820 14:27:42.072472 3985483 fbcode/monarch/hyperactor/src/channel/net.rs:644] session [email protected]: failed to deliver message within timeout
Successfully called Policy.update() to load weights from torchstore!
Successfully got model state dict after update
Loaded tensor model.layers.0.self_attn.qkv_proj.weight correctly validated
Loaded tensor model.layers.0.self_attn.o_proj.weight correctly validated
Loaded tensor model.layers.0.mlp.gate_up_proj.weight correctly validated
Loaded tensor model.layers.0.mlp.down_proj.weight correctly validated
Loaded tensor model.layers.0.input_layernorm.weight correctly validated
Loaded tensor model.layers.0.post_attention_layernorm.weight correctly validated
Successfully validated that all 6 loaded tensors equal original

Test passed! State dict successfully loaded into Policy!

Tensor parallel test passed! Full state dict successfully loaded into tensor parallel model!
PASSED

===================================================================================== warnings summary =====================================================================================
tests/integration_tests/test_policy_update.py: 196 warnings
  /home/ankitageorge/.fbpkg_conda_envs/forge-8e46c2a/lib/python3.10/site-packages/torchstore/store.py:54: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
    logger.warn(f"Putting {key}")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================= 1 passed, 196 warnings in 319.12s (0:05:19) ==
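The drop from 291 to 195 parameters in the log above is consistent with vLLM's fused-weight layout, where separate projections are concatenated into single tensors (e.g. `q_proj`/`k_proj`/`v_proj` into `qkv_proj`). A minimal sketch of that kind of fusion, with toy shapes; the helper name and shapes are illustrative, not the PR's actual conversion code:

```python
import torch

def fuse_qkv(sd: dict, layer: int) -> torch.Tensor:
    """Concatenate separate q/k/v projection weights along the output
    dimension (dim 0), matching vLLM's fused qkv_proj layout."""
    prefix = f"model.layers.{layer}.self_attn"
    return torch.cat(
        [sd[f"{prefix}.{p}_proj.weight"] for p in ("q", "k", "v")], dim=0
    )

# Toy example: hidden size 8, with 8 query rows and 4 rows each for k/v
# (GQA-style), so three parameters collapse into one fused tensor.
sd = {
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.k_proj.weight": torch.ones(4, 8),
    "model.layers.0.self_attn.v_proj.weight": torch.full((4, 8), 2.0),
}
fused = fuse_qkv(sd, 0)
print(fused.shape)  # torch.Size([16, 8])
```

Each such fusion replaces three entries with one, which is how a 291-parameter transformers state dict can shrink to 195 vLLM-format tensors.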

@meta-cla meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) Aug 18, 2025
Contributor

@allenwang28 allenwang28 left a comment

This is super cool @ankitageorge! Left a quick round of review as I know Philip is on-site in Bellevue and I wanted to help a bit

@ankitageorge ankitageorge marked this pull request as draft August 18, 2025 19:07
@ankitageorge
Contributor Author

converting to draft while I fix some things, will re-open when ready for review

Returns:
torch.Tensor: The sharded tensor for this rank
"""
tp_rank = self.rank % self.tensor_parallel_size
Contributor
As per the impl, we do:

  • even sharding
  • placement on every rank

It's probably good to document that contract/policy.
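The contract this comment asks to document — even sharding, with a shard placed on every rank — can be sketched as follows. The function name and sizes are illustrative; the `rank % tensor_parallel_size` indexing mirrors the `tp_rank` computation in the diff above:

```python
import torch

def shard_for_rank(full: torch.Tensor, rank: int, tp_size: int, dim: int) -> torch.Tensor:
    """Evenly shard `full` along `dim` and return this rank's piece.
    Assumes the dimension divides evenly by tp_size, which is the
    even-sharding contract being documented."""
    assert full.size(dim) % tp_size == 0, "even-sharding contract violated"
    tp_rank = rank % tp_size  # same indexing as the Policy impl above
    return full.chunk(tp_size, dim=dim)[tp_rank]

# e.g. a column-parallel weight sharded across 2 TP ranks along dim 0
w = torch.arange(16.0).reshape(8, 2)
shard = shard_for_rank(w, rank=1, tp_size=2, dim=0)
print(shard.shape)  # torch.Size([4, 2])
```

Rank 1 gets the second contiguous block of rows; every rank holds exactly one shard.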

def _get_tensor_parallel_sharding_strategy(param_name: str) -> tuple[int, bool]:
"""
Determine the sharding strategy for a parameter in tensor parallel setup.
This mirrors the logic from Policy._get_tensor_parallel_sharding_strategy.
Contributor

Why do we have duplicate sharding-strategy code? Alternatively we could:
1. Test the sharding util in an isolated unit test (without full-model-size complications).
2. Use that util in both the prod and test code paths.

@pradeepfn
Contributor

Overall LGTM!
Nit comment on the Policy code (even though it was not directly added by this PR):
the policy seems to be specific to the Llama model, and IIUC there could be other policy/conversion utils for other models (?).
If so, should we qualify them directly? @pbontrager

Contributor

@ebsmothers ebsmothers left a comment
This is great! (I know this is still in draft, so feel free to ignore any comments that you were already working on anyways)

- Output layer: shard along vocab dimension (dim 0)
"""
# Parameters that are not sharded (replicated across all tensor parallel ranks)
if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
Contributor
Imo we should associate logic like this with the model somehow rather than make it a fixed property of the Policy class. Happy to brainstorm a bit more on the right way to do this (also I assume the TP strategy here is unique to vLLM and does not in general match what's defined in titan?)
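One way to move this logic onto the model, as the comment suggests, is a per-model sharding plan looked up by parameter-name suffix instead of keyword checks hard-coded in Policy. This is a hypothetical sketch — the plan name, entries, and helper are illustrative, and the dims follow the usual Megatron-style scheme (column-parallel weights shard the output dim, row-parallel the input dim):

```python
# Hypothetical per-model plan: maps a parameter-name suffix to
# (shard_dim, is_sharded); shard_dim None means replicated on all TP ranks.
LLAMA_TP_PLAN = {
    "qkv_proj.weight": (0, True),       # column-parallel: shard output dim
    "gate_up_proj.weight": (0, True),
    "o_proj.weight": (1, True),         # row-parallel: shard input dim
    "down_proj.weight": (1, True),
    "input_layernorm.weight": (None, False),         # replicated
    "post_attention_layernorm.weight": (None, False),
}

def sharding_for(param_name: str, plan=LLAMA_TP_PLAN):
    """Look up the TP sharding strategy for a parameter by suffix match."""
    for suffix, strategy in plan.items():
        if param_name.endswith(suffix):
            return strategy
    return (None, False)  # default: replicate anything unrecognized

print(sharding_for("model.layers.0.self_attn.o_proj.weight"))  # (1, True)
```

A different model would register its own plan, so Policy never needs model-specific keyword lists.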

return self.vllm_args

@endpoint
async def get_model_params(self):
Contributor
Is this function purely for testing, or we plan to leave it in for the final implementation?

Contributor Author
in theory just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vllm to the test for comparison with the saved state dict
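What such a test-only endpoint has to return is essentially a CPU snapshot of the loaded parameters so the test can compare them against the state dict written to torchstore. A plain-function sketch (the actor/endpoint decorator and function name are omitted/hypothetical):

```python
import torch

def snapshot_params(model: torch.nn.Module) -> dict:
    """Detach and copy parameters to CPU so a test can compare them
    against the state dict that was written to torchstore."""
    return {name: p.detach().cpu().clone() for name, p in model.named_parameters()}

# Toy usage: a linear layer stands in for the vLLM model.
m = torch.nn.Linear(4, 2, bias=False)
snap = snapshot_params(m)
print(sorted(snap))  # ['weight']
```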

Contributor
Is this a method that can be patched into the actor class in the test? For example you can do a TestPolicy(Policy) and then add this method.

Contributor Author
Ya I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing seems to work.

@ankitageorge ankitageorge marked this pull request as ready for review August 20, 2025 21:32
@ankitageorge
Contributor Author

re-opening for review

Contributor

@pbontrager pbontrager left a comment

Thank you so much for putting this together. I've requested a few changes, but they're fairly small, so I'll pre-approve this.

@ankitageorge ankitageorge merged commit bbbc169 into main Aug 21, 2025
4 checks passed
@Jack-Khuu
Contributor

QQ: how is torchstore being installed?

@ankitageorge
Contributor Author

> QQ: how is torchstore being installed?

https://github.com/meta-pytorch/forge/blob/main/pyproject.toml#L48
