[WIP] PTD / Gloo transport implementation #44
base: main
Conversation
torchstore/transport/pipe.py
Outdated
latency_trcker.track_step("allocate")

# TODO: booooo
Not a huge fan of this, need to think of something better
state["transport_context"] = None | ||
return state | ||
|
||
async def setup_comms(self, storage_volume): |
Sorry, it's not clear to me what the topology looks like for multiple replicas.
@JenniferWang each client/volume pair creates a 1:1 process group for send/recv which is cached.
For clarity, FileStore is getting replaced with TCPStore before merging.
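To make the described topology concrete, here is a minimal sketch, not the PR's actual code: one lazily created, cached two-rank Gloo process group per client/volume pair, rendezvoused over a TCPStore as mentioned above. The cache structure, function name, and constructor wiring (`get_pair_process_group`, `_pg_cache`, direct use of `ProcessGroupGloo`) are assumptions for illustration.

```python
import torch.distributed as dist

# Hypothetical cache: one two-rank process group per (client, volume) pair.
_pg_cache = {}

def get_pair_process_group(client_id, volume_id, host, port, is_client):
    key = (client_id, volume_id)
    if key not in _pg_cache:
        # Two-rank rendezvous: one side hosts the TCPStore, the other joins it.
        store = dist.TCPStore(host, port, world_size=2, is_master=not is_client)
        rank = 1 if is_client else 0  # assumed: volume = rank 0, client = rank 1
        _pg_cache[key] = dist.ProcessGroupGloo(store, rank, 2)
    return _pg_cache[key]
```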
This is incredible! Huge congratulations for making this work!
Left some comments and questions.
self.objects = other_buffer.objects
self.requires_meta = other_buffer.requires_meta

async def setup_comms(self, storage_volume) -> None:
Is this safe to call concurrently, and is it idempotent? Based on how you use it in the create-transport-buffer code, I assume it is more like ensure_comms?
Ah actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo.
We may need a lock based on the client, wdyt?
> Ah actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo.
> We may need a lock based on the client, wdyt?
Not in the scope of this PR but we can add a TODO and add an issue just to keep track of this.
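One possible shape for that TODO is sketched below, assuming nothing about the PR beyond the existence of setup_comms: a per-volume asyncio.Lock plus a "ready" set makes repeated or concurrent calls behave like an idempotent ensure_comms. The class and attribute names are hypothetical.

```python
import asyncio

class CommsMixin:
    """Hypothetical wrapper; setup_comms stands in for the existing non-idempotent method."""

    def __init__(self):
        self._comms_locks = {}      # volume key -> asyncio.Lock
        self._comms_ready = set()   # volume keys whose comms are already set up

    async def setup_comms(self, storage_volume) -> None:
        ...  # the PR's per-pair process-group setup would live here

    async def ensure_comms(self, storage_volume) -> None:
        key = id(storage_volume)  # placeholder key; real code would use a stable volume id
        lock = self._comms_locks.setdefault(key, asyncio.Lock())
        async with lock:
            if key not in self._comms_ready:
                await self.setup_comms(storage_volume)
                self._comms_ready.add(key)
```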
transport_buffer = self.create_transport_buffer()

async def get_from_storage_volume(self, key, request: Request):
    latency_trcker = LatencyTracker(f"get_from_storage_volume:{key}")
nit: typo
# TODO: re-evaluate thiss logic for better polymorphism
t = None
if transport_buffer.read_ahead:
can you explain what read_ahead means?
This is a bit of a bummer, but for PTD-style comms I need to call recv before calling the storage volume's equivalent (send), otherwise the code hangs.
I'd like to circle back here and figure out one path that works for all transport buffers, but I'm punting this down the road in the interest of getting something working.
# TODO: consider placing the buffer inside the request or vice versa
transport_buffer.update(
    await self.storage_volume.get.call_one(
        key, transport_buffer, request.meta_only()
if the buffer is not read_ahead, then the storage_volume does the read/write, correct?
The storage volume always does the read ahead; the only thing that changes is the order. In the non-PTD case we call storage_volume.get before we call transport_buffer.read_into in the client.
A bit messy -- open to suggestions here
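A minimal sketch of that ordering, using the names visible in the quoted diff (read_ahead, storage_volume.get.call_one, transport_buffer.read_into); the exact signatures and the surrounding class are assumptions, not the PR's code.

```python
import asyncio

async def get_from_storage_volume(self, key, request):
    transport_buffer = self.create_transport_buffer()

    if transport_buffer.read_ahead:
        # PTD/Gloo path: post the local recv first, then ask the volume to send;
        # the reverse order deadlocks on the blocking send/recv pair.
        recv_task = asyncio.create_task(transport_buffer.read_into(request))
        await self.storage_volume.get.call_one(key, transport_buffer, request.meta_only())
        await recv_task
    else:
        # Non-PTD path: fetch from the storage volume first, then read out of the buffer.
        transport_buffer.update(
            await self.storage_volume.get.call_one(key, transport_buffer, request.meta_only())
        )
        await transport_buffer.read_into(request)
```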
Overall LGTM, except two remaining concerns:
- I don't think a separate finish() method is necessary; it makes more sense to keep the original semantics of read_into and write_from by creating a background task to poll the PyTorch future. Am I thinking straight here?
- Can we add a TODO and a tracking issue to document the unsafe behavior of setup_comms?
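For the first concern, one hypothetical shape, not the PR's code, is to have write_from await the pending future itself via a small polling helper, so callers would no longer need a separate finish(); pg.send and the attribute names are taken from the quoted diff below, everything else is assumed.

```python
import asyncio

async def _await_pending(fut, poll_s: float = 0.01):
    # Assumes `fut` exposes done()/wait() like torch.futures.Future; if pg.send()
    # returns a c10d Work object instead, is_completed()/wait() would be the analogue.
    while not fut.done():
        await asyncio.sleep(poll_s)
    fut.wait()

async def write_from(self, tensor):
    pg = self.transport_context[self.file_store_name]
    fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
    await _await_pending(fut)  # completes in-line instead of deferring to finish()
```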
assert self.fut is None
pg = self.transport_context[self.file_store_name]
self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
This is useful so we can do things concurrently while the future is pending (actually necessary so we can schedule the recv in the storage volume from the same thread).
From all the code that uses either read_into or write_from, I always see buffer.finish() follow immediately? What I am saying is that a pytorch.futures.Future can be converted to an asyncio-style future simply by creating a polling task, so the additional finish() is not necessary.
Suggested change:
self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)

async def poll_pt_future(fut):
    while not fut.done():
        await asyncio.sleep(0.01)  # or other poll frequency

await asyncio.create_task(poll_pt_future(self.fut))
torchstore/storage_volume.py
Outdated
for shard in self.kv[key].values():
    if shard["slice"] == request.tensor_slice:
        await transport_buffer.write_from(shard["tensor"])
        transport_buffer.finish()
Suggested change:
await transport_buffer.finish()