Efficiency of treetensor.torch.as_tensor #19

@LamannaLeonardo

I am seeing a performance problem that appears to come from calling treetensor.torch.as_tensor on a dictionary. Here is a minimal reproduction (Python 3.10, DI-treetensor==0.5.0, torch==2.6.0):

import treetensor.torch as ttorch
import torch

if __name__ == '__main__':
    obs = torch.randn(1, 28, 28)
    for i in range(5000):
        # fast = ttorch.as_tensor(obs)
        slow = ttorch.as_tensor({'key': obs})

With ttorch.as_tensor(obs) the loop takes about 1 second, while with ttorch.as_tensor({'key': obs}) it takes about 16 seconds. Profiling the code above shows that the __repr__ method of the tensor obs is called on every iteration, accounting for about 15 of the 16 seconds.

[Image: profiler output]
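
For anyone who wants to collect a comparable profile, here is a minimal sketch using the standard library's cProfile and pstats. The helper name top_functions is hypothetical, and the workload in the demo is a stand-in so that the sketch runs without extra packages; it is not the code that produced the numbers above.

```python
import cProfile
import io
import pstats


def top_functions(fn, repeat=100, limit=10):
    """Run fn() `repeat` times under cProfile and return a text report
    listing the `limit` most expensive functions by cumulative time.

    `top_functions` is a hypothetical helper written for this sketch.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(repeat):
        fn()
    profiler.disable()

    # Write the report into a string instead of stdout.
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("cumulative").print_stats(limit)
    return buffer.getvalue()


if __name__ == "__main__":
    # Stand-in workload: formatting a dictionary as a string.
    data = {"key": list(range(1000))}
    print(top_functions(lambda: repr(data)))
```

For the reproduction above, the call would be top_functions(lambda: ttorch.as_tensor({'key': obs})), with obs defined as in the snippet. A __repr__ entry near the top of the report would match what I observed.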

In my case, I ran into this slowdown while collecting rollouts in DI-engine, where transition dictionaries are converted to tensors with ttorch.as_tensor at this line.

Can this slowdown be mitigated? I understand that converting a dictionary to a tree tensor requires some additional computation, but I wonder whether it can be optimised. Also, adding torch.set_printoptions(precision=3, threshold=10) reduces the run time significantly (from 16 to 2.5 seconds), but I am not sure what side effects this might have, or why the print options affect the conversion of a dictionary to a tree tensor at all.
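
One way to limit the side effects of the print-option workaround might be to scope it to the conversion code only. The sketch below is an assumption-laden illustration, not a fix in the library: compact_tensor_repr is a hypothetical helper, and it restores PyTorch's default print options on exit rather than the caller's previous ones, because torch.set_printoptions has a reset (profile="default") but no public getter. Note that obs has 1 * 28 * 28 = 784 elements, which is below PyTorch's default summarisation threshold of 1000, so by default every element is formatted; with threshold=10 the output is summarised, which is a plausible reason the workaround is faster.

```python
from contextlib import contextmanager

import torch


@contextmanager
def compact_tensor_repr(precision=3, threshold=10):
    """Temporarily shrink tensor printing, then reset to PyTorch's defaults.

    Hypothetical helper. Assumption: the caller was using the default
    print options, since they are reset (not restored) on exit.
    """
    torch.set_printoptions(precision=precision, threshold=threshold)
    try:
        yield
    finally:
        # Reset precision, threshold, edgeitems and linewidth to defaults.
        torch.set_printoptions(profile="default")


if __name__ == "__main__":
    obs = torch.randn(1, 28, 28)
    full = len(repr(obs))
    with compact_tensor_repr():
        short = len(repr(obs))
    print(f"repr length: default={full}, compact={short}")
```

In the reproduction, the loop that calls ttorch.as_tensor({'key': obs}) would go inside the with block. This only helps if __repr__ really is the bottleneck, as the profile suggests, and it does not explain why __repr__ is called during the conversion in the first place.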

Thank you
