Conversation

LucasLLC (Contributor)

[WIP] PTD / Gloo transport implementation

@meta-cla bot added the CLA Signed label on Sep 29, 2025.

```python
latency_trcker.track_step("allocate")

# TODO: booooo
```
LucasLLC (author):

Not a huge fan of this, need to think of something better



state["transport_context"] = None
return state

async def setup_comms(self, storage_volume):


Sorry, it's not clear to me what the topology looks like for multiple replicas.

LucasLLC (author):

@JenniferWang each client/volume pair creates a 1:1 process group for send/recv, which is cached.
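
For illustration, a minimal sketch of that caching scheme. The names `_pg_cache` and `get_pair_group` are hypothetical, and it assumes direct `ProcessGroupGloo` construction rather than this PR's actual wiring:

```python
import torch.distributed as dist

# Hypothetical sketch of the scheme described above: one two-rank Gloo
# group per client/volume pair, created lazily and cached for reuse.
_pg_cache: dict[str, dist.ProcessGroup] = {}

def get_pair_group(pair_key: str, store: dist.Store, rank: int) -> dist.ProcessGroup:
    """Return the cached 1:1 group for this pair, creating it on first use."""
    if pair_key not in _pg_cache:
        # world_size is always 2: rank 0 = client, rank 1 = storage volume.
        _pg_cache[pair_key] = dist.ProcessGroupGloo(store, rank, 2)
    return _pg_cache[pair_key]
```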

LucasLLC (author):

For clarity: FileStore is getting replaced with TCPStore before merging.
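
For reference, the two stock PyTorch rendezvous stores differ mainly in what they require: FileStore needs a path visible to both ranks (a shared filesystem), while TCPStore only needs network reachability. A minimal sketch with a hypothetical path, host, and port, not this PR's configuration:

```python
import torch.distributed as dist

# FileStore: both ranks must see the same path (shared filesystem).
file_store = dist.FileStore("/tmp/ptd_rendezvous", 2)  # hypothetical path

# TCPStore: one rank hosts, the other connects; no shared storage needed.
tcp_store = dist.TCPStore("127.0.0.1", 29500, 2, is_master=True)  # hypothetical host/port
```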

@casteryh (Contributor) left a comment:

This is incredible! Huge congratulations for making this work!

Left some comments and questions.

```python
    self.objects = other_buffer.objects
    self.requires_meta = other_buffer.requires_meta

async def setup_comms(self, storage_volume) -> None:
```
Contributor:

Is this safe to call concurrently, and is it idempotent? Based on how you use it in the create-transport-buffer code, I assume it is more like `ensure_comms`?

LucasLLC (author):

Ah actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo.

We may need a lock based on the client, wdyt?

Contributor:

> Ah actually it's not safe or idempotent. It's also not safe to call concurrently from the same client/volume combo.
>
> We may need a lock based on the client, wdyt?

Not in the scope of this PR, but we can add a TODO and file an issue just to keep track of this.
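
A possible shape for that follow-up, sketched below. The names `ensure_comms`, `_setup_locks`, and `_ready` are hypothetical, not code from this PR:

```python
import asyncio
from collections import defaultdict

# TODO(tracking issue): setup_comms is neither idempotent nor safe to
# call concurrently. Hypothetical mitigation: serialize setup per
# client/volume pair behind a lock and make the call ensure-style.
_setup_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
_ready: set[str] = set()

async def ensure_comms(pair_key: str, do_setup) -> None:
    async with _setup_locks[pair_key]:
        if pair_key not in _ready:  # idempotence: only the first caller runs setup
            await do_setup()
            _ready.add(pair_key)
```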


```python
transport_buffer = self.create_transport_buffer()

async def get_from_storage_volume(self, key, request: Request):
    latency_trcker = LatencyTracker(f"get_from_storage_volume:{key}")
```
Contributor:

nit: typo (`latency_trcker`)


```python
# TODO: re-evaluate this logic for better polymorphism
t = None
if transport_buffer.read_ahead:
```
Contributor:

Can you explain what `read_ahead` means?

LucasLLC (author):

This is a bit of a bummer, but for PTD-style comms I need to call `recv` before calling the storage volume equivalent (`send`), otherwise the code hangs.

I'd like to circle back here and figure out one path that works for all transport buffers, but I'm punting this down the road in the interest of getting something working.
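
Concretely, the constraint is an ordering one: the client has to post its recv before awaiting the storage volume's send, or both sides block. A hypothetical fragment of the PTD path, using the same `pg.send`/`pg.recv` call style as this PR's snippets:

```python
# Hypothetical sketch of the ordering constraint described above.
# Gloo send/recv is rendezvous-style, so the recv must already be in
# flight before we await the remote send, otherwise both sides hang.
fut = pg.recv([tensor], srcRank=remote_rank, tag=0)        # 1. post recv first
await storage_volume.get.call_one(key, transport_buffer)   # 2. remote side sends
fut.wait()                                                 # 3. recv completes
```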

```python
# TODO: consider placing the buffer inside the request or vice versa
transport_buffer.update(
    await self.storage_volume.get.call_one(
        key, transport_buffer, request.meta_only()
```
Contributor:

If the buffer is not `read_ahead`, then the storage_volume does the read/write, correct?

LucasLLC (author):

The storage volume always does the read/write; the only thing that changes is the order. In the non-PTD case we call `storage_volume.get` before we call `transport_buffer.read_into` in the client.

A bit messy -- open to suggestions here
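
So the client-side get flow differs only in ordering, roughly as in this hypothetical sketch of the branch described above (only `read_ahead`, `read_into`, and `storage_volume.get` come from the PR; the rest is placeholder):

```python
if transport_buffer.read_ahead:
    # PTD/Gloo path: post the local recv (read) before triggering the remote send.
    t = asyncio.create_task(transport_buffer.read_into(tensor))
    await storage_volume.get.call_one(key, transport_buffer, request)
    await t
else:
    # Non-PTD path: the storage volume produces the data first, then the client reads it.
    await storage_volume.get.call_one(key, transport_buffer, request)
    await transport_buffer.read_into(tensor)
```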

@casteryh (Contributor) left a comment:

Overall LGTM, except two remaining concerns:

  • I don't think a separate `finish()` method is necessary; it makes more sense to keep the original semantics of `read_into` and `write_from` by creating a background task to poll the PyTorch future. Am I thinking straight here?
  • Can we add a TODO and a tracking issue to document the unsafe behavior of `setup_comms`?


```python
assert self.fut is None
pg = self.transport_context[self.file_store_name]
self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
```
@casteryh commented on Oct 1, 2025:

> this is useful so we can do things concurrently while the future is pending (actually necessary so we can schedule the recv in the storage volume from the same thread)

From all the code that uses either `read_into` or `write_from`, I always see `buffer.finish()` follow immediately?

What I am saying is that a `torch.futures.Future` can be converted to an asyncio-style future simply by creating a polling task, so the additional `finish()` is not necessary.

Suggested change:

```diff
-self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
+self.fut = pg.send([tensor], dstRank=self.remote_rank, tag=0)
+
+async def poll_pt_future(fut):
+    while not fut.done():
+        await asyncio.sleep(0.01)  # or other poll frequency
+
+await asyncio.create_task(poll_pt_future(self.fut))
```

```python
for shard in self.kv[key].values():
    if shard["slice"] == request.tensor_slice:
        await transport_buffer.write_from(shard["tensor"])
        transport_buffer.finish()
```
Contributor:

Suggested change:

```diff
-        transport_buffer.finish()
+        await transport_buffer.finish()
```
