shared memory multiprocess prefetch for weight update #430
base: main
Conversation
    log_stats=None,
)
self._start_processing()
fetcher_procs = this_host().spawn_procs(
Can we guard this with `if prefetch_weights`? It may also be a bit cleaner to put this init in another function, with some documentation for what the fetchers here do.
will move to a separate function.
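A rough sketch of what that guarded, factored-out init could look like, as a method on the Generator; the helper name and the fetcher actor-spawn call are assumptions rather than the PR's final code:

# Sketch only: helper name and actor-spawning details are assumptions.
def _maybe_spawn_weight_fetchers(self) -> None:
    """Spawn _WeightFetcher actors on this host. They prefetch weights from
    torchstore into shared memory so the GPU workers can map them locally
    instead of each pulling over the network. No-op when prefetch is disabled."""
    if not self.prefetch_weights_to_shm:
        self._fetchers = None
        return
    fetcher_procs = this_host().spawn_procs(
        per_host={"procs": self.n_fetcher_procs}
    )
    # Assumed spawn signature: (actor name, actor class).
    self._fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)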
)
self._start_processing()
fetcher_procs = this_host().spawn_procs(
    per_host={"procs": self.n_fetcher_procs}
Not something we need to address now, but I think we need to spawn these fetcher procs across all generator nodes, right?
Yes, but I assume the setup call is broadcast and every Generator node will spawn its own fetcher_procs?
That is correct. I meant the case where a generator's workers span 2 nodes, e.g. DeepSeek. In that case we would probably want to spin up the fetchers on the worker nodes, right?
Correct
I am a bit confused -- shouldn't a vLLM worker be scoped to a single node?
I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?
@casteryh could you split the PR into separate concerns? Also maybe add a comment somewhere saying the following is up for discussion.
I also don't quite get the "before" and "after" table.
I think it makes little sense to split, see below.
Currently TorchStore RDMA only works CPU-to-CPU.
This actually comes automatically once you have separate processes fetching the weights to shared memory.
Will do
Ah, maybe it's confusing because I am trying to do two things at once. Basically the speedup comes from 1. multiprocess shared memory (this saves 30 seconds) and 2. prefetching while on-the-fly generation completes (this saves about 10 seconds on average).
Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- it should be "use shared memory or not". I think you should try profiling with proc = 8 / replica for the policy.
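For illustration, a minimal sketch of the prefetch-during-generation overlap described two comments up; fetch_all and finish_inflight_generation are hypothetical names, and only update_weights appears elsewhere in this PR:

import asyncio

async def sync_weights_with_prefetch(generator, fetchers, version: int) -> None:
    # Start copying the new weights from torchstore into shared memory right away.
    prefetch = asyncio.create_task(fetchers.fetch_all(version))   # hypothetical endpoint
    # Let requests that are already in flight finish on the old weights.
    await generator.finish_inflight_generation()                  # hypothetical method
    # Most of the fetch has now overlapped with decoding; only the load remains.
    handles = await prefetch
    await generator.update_weights(handles)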
@JenniferWang
@allenwang28 @JenniferWang ptal
Yes, I was thinking that tp = 8 on the policy would be worse without shared memory?
I think we should make shared memory the default for CPU-based weight sync, with a flag to turn it off.
This commit fixes multiple memory leak issues in the SharedTensor implementation by introducing explicit lifecycle management and proper cleanup patterns.

Key changes:
1. Fixed __del__ bug: changed hasattr(self, "shm") to check "_shm"
2. Added an explicit close() method for releasing shared memory handles
3. Changed tensor from @cached_property to @property with manual caching
4. Added closed-state tracking with an is_closed property
5. Made tensor access after close() raise RuntimeError (fail fast)
6. Made get_handle() after close() raise RuntimeError
7. Updated drop() to call close() first, then unlink
8. Added context manager support (__enter__/__exit__)
9. Fixed _WeightFetcher to explicitly close after getting the handle
10. Fixed GeneratorWorker to close shared memory after loading weights
11. Optimized SharedTensorHandle.drop() to not create unnecessary instances

Memory leak prevention:
- Creators must call close() after getting the handle
- Receivers must call close() after using the tensor
- One process should call drop() to unlink after all are done
- close() and drop() are idempotent and safe to call multiple times

Documentation:
- Added a comprehensive class docstring with the lifecycle model
- Documented that cached tensor references become invalid after close()
- Added warnings about not relying on __del__ for cleanup
- Added 12 new tests for close/cleanup behavior

Test results: 65/65 tests pass with no warnings
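A minimal sketch of the lifecycle model described above, built on the standard library's multiprocessing.shared_memory; only the listed behaviors (idempotent close(), drop() = close + unlink, fail-fast access after close, context-manager support, the __del__ attribute check) come from the commit message, while the constructor and attribute layout here are assumptions:

from multiprocessing import shared_memory

import torch


class SharedTensorSketch:
    """Illustrative stand-in for SharedTensor: explicit close()/drop() lifecycle."""

    def __init__(self, tensor: torch.Tensor):
        data = tensor.detach().contiguous().cpu()
        self._numel = data.numel()
        self._dtype = data.dtype
        self._shape = tuple(data.shape)
        self._shm = shared_memory.SharedMemory(
            create=True, size=max(1, self._numel * data.element_size())
        )
        # Copy the weight into the shared segment; the temporary view is dropped right away.
        torch.frombuffer(self._shm.buf, dtype=self._dtype, count=self._numel).copy_(
            data.flatten()
        )
        self._closed = False

    @property
    def is_closed(self) -> bool:
        return self._closed

    @property
    def tensor(self) -> torch.Tensor:
        if self._closed:  # fail fast instead of handing out a dangling view
            raise RuntimeError("SharedTensor is closed")
        return torch.frombuffer(
            self._shm.buf, dtype=self._dtype, count=self._numel
        ).view(self._shape)

    def close(self) -> None:
        # Release this process's handle only; idempotent.
        if not self._closed:
            self._shm.close()
            self._closed = True

    def drop(self) -> None:
        # Close, then remove the backing segment for every process.
        self.close()
        try:
            self._shm.unlink()
        except FileNotFoundError:
            pass  # another process already unlinked it

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()

    def __del__(self):
        # The __del__ bug fix: check the real attribute name before touching it.
        if hasattr(self, "_shm"):
            self.close()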
Refactor generator code to use the context manager pattern (with statement) for SharedTensor cleanup instead of explicit close() calls. This provides:
- Clearer intent: the context manager makes the lifecycle explicit
- Automatic cleanup: ensures close() is called even on exceptions
- More idiomatic Python: the standard pattern for resource management

Changes:
- GeneratorWorker.update_weights(): use 'with' for SharedTensor created from handles
- _WeightFetcher.fetch(): use 'with' when creating SharedTensor and getting the handle

The context manager automatically calls close() on exit, making the code more concise and safer.
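Roughly, the receiver-side pattern described above (which the diff hunks later in this thread do not show) would look like the following; the handle-based constructor kwarg and the load_weight helper are assumptions, not the PR's code:

def load_param_from_handle(handle) -> None:
    # Receiver side: map the segment, copy the weight into the engine,
    # then release this process's handle; the creator unlinks later via drop().
    with SharedTensor(handle=handle) as shared_tensor:   # constructor kwarg is an assumption
        load_weight(shared_tensor.tensor)                # hypothetical engine-side loader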
Yes, I believe it was ~100 seconds without shared memory for tp = 8, but I have a problem with my Slurm node and can't test now.
Change prefetch_weights_to_shm from False to True to enable the new shared memory-based weight prefetching feature by default.
done
Remove qwen3_32b_experimental.yaml as the shared memory weight prefetching feature is now enabled by default and no longer experimental.
fixed a memory leak
    shm.close()
    shm.unlink()
except Exception:
    pass
What's the consideration behind swallowing the exceptions in cleaning up the resource?
To make this idempotent and safe to call from multiple processes. Open to other ideas.
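One narrower alternative, sketched here as an assumption rather than the PR's code: only swallow the error that signals another process already unlinked the segment, and let anything unexpected surface.

from multiprocessing import shared_memory

def drop_segment(name: str) -> None:
    """Idempotent, multi-process-safe unlink of a named shared memory segment."""
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        return  # already unlinked by another process
    shm.close()
    try:
        shm.unlink()
    except FileNotFoundError:
        pass  # raced with another process's unlink; safe to ignore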
Co-authored-by: Jiyue Wang <[email protected]>
This is awesome! I think we'll want to go over how we're doing prefetch again after some of this is upstreamed to torchstore, but otherwise it looks great. I wonder if this is too risky of a change to make before PTC, though?
)

class _WeightFetcher(ForgeActor):
I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious if this has to actually be a separate process since this is an async method and I would think most of the time it's waiting on ts.get.
This has to be a separate actor because it has to be launched in a separate process.
param_key = get_param_key(version, name)
param = await ts.get(param_key)
# Use context manager to ensure cleanup after getting handle
with SharedTensor(tensor=param) as shared_tensor:
Is the plan to move this to TS and hide the rdma/shared memory logic from the user?
Hopefully yes.
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
use_dcp_for_weight_sync: bool | None = None
prefetch_weights_to_shm: bool = True
In general we should try to avoid changing the "public" api when we expect to quickly change the backend again. After launch we should try to keep this in mind.
agreed.
Yes, it's 2x faster than 1 process. I haven't tuned this parameter too much, though. @pbontrager
I am testing its stability right now. But fwiw, the current main is not stable / well tested either.
We can also switch the flag to be False by default.
What this PR does
Perf
TL;DR: e2e weight sync time is now ~50s for Qwen3 32B; one training step takes <70s
Tested with