Load model from torchstore into vLLM #55
Conversation
This is super cool @ankitageorge! Left a quick round of review as I know Philip is on-site in Bellevue and I wanted to help a bit.

Converting to draft while I fix some things, will re-open when ready for review.
src/forge/actors/policy.py
Outdated
    Returns:
        torch.Tensor: The sharded tensor for this rank
    """
    tp_rank = self.rank % self.tensor_parallel_size
As per the impl, we do:
- even sharding
- placement on every rank

Probably good to document the contract/policy, e.g. something like the sketch below.
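A minimal illustration of that contract; the helper name and signature are hypothetical, and only the `rank % tensor_parallel_size` indexing comes from the diff:

```python
import torch

def shard_for_rank(full: torch.Tensor, rank: int, tp_size: int, dim: int = 0) -> torch.Tensor:
    """Even sharding: the sharded dimension must divide evenly by tp_size,
    and every rank holds exactly one shard (no rank is skipped)."""
    tp_rank = rank % tp_size  # same indexing as the diff above
    assert full.size(dim) % tp_size == 0, "contract: even sharding only"
    shard = full.size(dim) // tp_size
    return full.narrow(dim, tp_rank * shard, shard)
```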
tests/test_vllm_torchstore.py
Outdated
def _get_tensor_parallel_sharding_strategy(param_name: str) -> tuple[int, bool]:
    """
    Determine the sharding strategy for a parameter in tensor parallel setup.
    This mirrors the logic from Policy._get_tensor_parallel_sharding_strategy.
Why do we have duplicate sharding strategy code? Alternatively we could:
1. Test the sharding util in an isolated unit test (without full-model-size complications).
2. Use the util in both the prod and test code paths (rough sketch below).
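Roughly what option 2 could look like, assuming the util is hoisted into a shared module; the module path is hypothetical, and the replicated-keyword check mirrors the diff in this PR:

```python
# Hypothetical shared module, e.g. src/forge/sharding.py
def get_tensor_parallel_sharding_strategy(param_name: str) -> tuple[int, bool]:
    """Return (shard_dim, is_sharded) for a parameter name."""
    # Replicated across all TP ranks (keyword list mirrors the PR diff).
    if any(k in param_name for k in ("norm", "bias", "rotary_emb")):
        return 0, False
    return 0, True  # default: shard along dim 0

# Both src/forge/actors/policy.py and tests/test_vllm_torchstore.py would then
# import this single implementation instead of duplicating it:
# from forge.sharding import get_tensor_parallel_sharding_strategy
```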
Overall LGTM!
This is great! (I know this is still in draft, so feel free to ignore any comments that you were already working on anyways)
src/forge/actors/policy.py
Outdated
    - Output layer: shard along vocab dimension (dim 0)
    """
    # Parameters that are not sharded (replicated across all tensor parallel ranks)
    if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
Imo we should associate logic like this with the model somehow rather than make it a fixed property of the Policy class. Happy to brainstorm a bit more on the right way to do this (also I assume the TP strategy here is unique to vLLM and does not in general match what's defined in titan?)
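One hypothetical shape for that brainstorm: hang the strategy off a per-model spec that the Policy consults, rather than hard-coding it in the class. All names below are invented for illustration:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TPShardingSpec:
    """Per-model description of which params are replicated vs. sharded."""
    replicated_keywords: tuple[str, ...] = ("norm", "bias", "rotary_emb")

    def strategy(self, param_name: str) -> tuple[int, bool]:
        if any(k in param_name for k in self.replicated_keywords):
            return 0, False  # replicated on every TP rank
        return 0, True  # sharded along dim 0

# Policy would then call model_spec.strategy(name) instead of owning the table.
```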
    return self.vllm_args

@endpoint
async def get_model_params(self):
Is this function purely for testing, or do we plan to leave it in for the final implementation?
In theory it's just for testing, but I think we need to leave it in, because I don't think there is another way to get the loaded params back from vLLM to the test for comparison with the saved state dict.
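For context, a sketch of what the test-only endpoint might return; the attribute path into vLLM internals is an assumption, not the actual implementation:

```python
@endpoint
async def get_model_params(self):
    # Hypothetical path to the underlying vLLM model; copy params to CPU so
    # the test can compare them against the saved state dict.
    model = self.worker.model_runner.model
    return {name: p.detach().cpu() for name, p in model.named_parameters()}
```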
Is this a method that can be patched into the actor class in the test? For example you can do a TestPolicy(Policy) and then add this method.
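I.e., something along these lines (a sketch only; whether a Monarch endpoint can be added on a subclass like this is exactly the open question):

```python
class TestPolicy(Policy):
    @endpoint
    async def get_model_params(self):
        ...  # test-only inspection of the loaded weights

# The test would then spawn TestPolicy in place of Policy.
```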
Ya, I can't seem to get this to work. I've tried what you suggested, and patching it in different ways, but nothing works.

Re-opening for review.
Thank you so much for putting this together. I've requested a few changes but they're fairly small, so I'll pre-approve this.
QQ: how is torchstore being installed?

https://github.com/meta-pytorch/forge/blob/main/pyproject.toml#L48
Add logic that loads the model from torchstore into vLLM. Handles the single-rank and distributed cases.
Adds a test that writes the model to torchstore and then reads from it, using the changes to the update method in the policy actor.
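At a high level, the described flow is something like the sketch below. The torchstore API call, key scheme, and helper names are assumptions for illustration; the real names live in the PR diff:

```python
import torch
import torchstore  # assumption: exposes an async get(key) -> torch.Tensor

async def load_from_torchstore(model: torch.nn.Module, key_prefix: str,
                               rank: int, tp_size: int) -> None:
    """Read each param from torchstore, shard it for this TP rank when the
    strategy says so, and load the result into the vLLM model."""
    state_dict = {}
    for name, _ in model.named_parameters():
        full = await torchstore.get(f"{key_prefix}/{name}")  # hypothetical key scheme
        # Shared sharding util as discussed in the review above.
        shard_dim, is_sharded = get_tensor_parallel_sharding_strategy(name)
        if is_sharded and tp_size > 1:
            size = full.size(shard_dim) // tp_size  # even sharding across ranks
            full = full.narrow(shard_dim, (rank % tp_size) * size, size)
        state_dict[name] = full
    model.load_state_dict(state_dict)
```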
Output: