Formulating some ideas from @allenwang28.
It would be cool if we had a data structure that satisfies the following:
- Is serializable and can be sent through Monarch RPC directly.
- Small objects (non-tensor data and metadata) can be accessed directly upon deserialization.
- Large objects/tensors are stored in torchstore and can be accessed on demand via a `materialize()`/`get()`/`fetch()` method.
- Is easy to declare.
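The first three requirements come down to a handle that serializes to a key plus a little metadata while the payload stays in the store. A minimal sketch, assuming an in-process dict (the hypothetical `InMemoryStore`) in place of torchstore, `pickle` as a stand-in for whatever serializer Monarch RPC uses, and a plain list in place of a tensor; `RemoteRef` and `dematerialize` are illustrative names, not an existing torchstore API:

```python
import asyncio
import pickle
from dataclasses import dataclass
from typing import Any


class InMemoryStore:
    """Hypothetical stand-in for torchstore: an async key-value store in this process."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    async def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    async def get(self, key: str) -> Any:
        return self._data[key]


STORE = InMemoryStore()  # a real version would talk to torchstore instead


@dataclass(frozen=True)
class RemoteRef:
    """Hypothetical handle: only the key and small metadata are serialized."""

    key: str
    length: int  # example of metadata readable without fetching the payload

    async def materialize(self) -> Any:
        # Fetch the payload from the store on demand.
        return await STORE.get(self.key)


async def dematerialize(key: str, value: list) -> RemoteRef:
    # Upload the large object and return a lightweight reference to it.
    await STORE.put(key, value)
    return RemoteRef(key=key, length=len(value))


async def main() -> None:
    actions = list(range(100_000))  # stands in for a large tensor
    ref = await dematerialize("some_key_prefix/actions", actions)
    wire = pickle.dumps(ref)  # roughly what an RPC layer would send
    received = pickle.loads(wire)
    print(len(wire), received.length)  # small message; metadata available immediately
    assert await received.materialize() == actions


asyncio.run(main())
```

Only `key` and `length` are serialized, so the message size does not grow with the payload, and `length` is readable right after deserialization. A real handle would presumably carry tensor metadata such as shape and dtype instead.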
Ideally we want to be able to do the following:
```python
import torch
from dataclasses import field
from torch.distributed.tensor import DTensor

from torchstore.dataclasses import dataclass as ts_dataclass  # a better name than `dataclass`?
from torchstore.dataclasses import RemoteTensor, mark_remote


@ts_dataclass
class Rollout:
    # stdlib dataclasses reject a mutable default like `{}`; use a factory instead
    data: dict = field(default_factory=dict)
    some_int: int = 0
    # There is probably a way to avoid explicitly specifying 'actions' here,
    # but it might not be worth the effort right now.
    actions: RemoteTensor | torch.Tensor | DTensor = mark_remote('actions')
    rewards: RemoteTensor | torch.Tensor | DTensor = mark_remote('rewards')
    envs: RemoteTensor | torch.Tensor | DTensor = mark_remote('envs')
```
Then, on the sending end, we can do:
```python
rollout = Rollout(prefix='some_key_prefix')  # or generate a random prefix for the user
# do something
rollout.data = some_trivially_serializable_python_object
rollout.actions = actions_tensor
# ...
# upload to torchstore and `dematerialize` the tensors before sending the object out
rollout = await rollout.to_remote()
await other_actor.do_something.call(rollout)
```
On the receiving end, we can do:
```python
# wait until everything is materialized
rollout = await rollout.materialize()
```
or
```python
metadata = rollout.data
# do something with the metadata
# pass to another actor without materializing
await yet_another_actor.do_something_else.call(rollout)
```
or
```python
actions = await rollout.actions.materialize()
# do some computation with actions; envs is never fetched from remote
```
Other nice-to-haves:
- Support the above in a nested way, i.e., a field itself can have remote references to tensors.