Description
I am experiencing an efficiency issue that appears to be caused by treetensor.torch.as_tensor applied to a dictionary. I provide minimal code for reproducibility (I am using Python 3.10, DI-treetensor==0.5.0, and torch==2.6.0):
import treetensor.torch as ttorch
import torch

if __name__ == '__main__':
    obs = torch.randn(1, 28, 28)
    for i in range(5000):
        # fast = ttorch.as_tensor(obs)
        slow = ttorch.as_tensor({'key': obs})
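For reference, the two variants can be timed separately with a small harness like the following (a minimal sketch using time.perf_counter; the timings in the comments are the ones I observe on my machine):

import time
import torch
import treetensor.torch as ttorch

obs = torch.randn(1, 28, 28)

def bench(fn, n=5000):
    # Run the conversion n times and return the elapsed wall-clock time.
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return time.perf_counter() - start

print('plain tensor:   ', bench(lambda: ttorch.as_tensor(obs)))           # ~1 s for me
print('dict of tensors:', bench(lambda: ttorch.as_tensor({'key': obs})))  # ~16 s for me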
The loop with ttorch.as_tensor(obs) takes about 1 second, while the one with ttorch.as_tensor({'key': obs}) requires about 16 seconds. Profiling the above code, I see that the __repr__ method of the tensor obs is called on every iteration, accounting for 15 of the 16 seconds.
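To see where the time goes, a quick cProfile run over the dictionary case can be used (a minimal sketch; the exact call counts will depend on the versions listed above):

import cProfile
import pstats
import torch
import treetensor.torch as ttorch

obs = torch.randn(1, 28, 28)

def convert_dict():
    for _ in range(5000):
        ttorch.as_tensor({'key': obs})

# Profile the dictionary conversion and print the heaviest calls by
# cumulative time; in my run, Tensor.__repr__ dominates.
cProfile.run('convert_dict()', 'as_tensor_dict.prof')
pstats.Stats('as_tensor_dict.prof').sort_stats('cumulative').print_stats(15)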
In my specific case, I ran into this efficiency issue when collecting rollouts in DI-engine, where transition dictionaries are converted to tensors via ttorch.as_tensor at this line.
Can this slowdown be mitigated? I understand that converting a dictionary into a tree tensor requires additional computation, but I am wondering whether some optimisation is possible. Moreover, adding torch.set_printoptions(precision=3, threshold=10) significantly reduces the run time (from 16 to 2.5 seconds), but I am not sure about possible side effects, nor why the print options affect converting a dictionary to a tree tensor.
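For completeness, the workaround I tried is simply setting the global print options before the loop (a minimal sketch; I am not sure this is safe in general, since it changes how all tensors are printed):

import torch
import treetensor.torch as ttorch

# Workaround: shrink the tensor repr so that the __repr__ calls made
# during the dict conversion become cheap. Note this changes printing globally.
torch.set_printoptions(precision=3, threshold=10)

obs = torch.randn(1, 28, 28)
for _ in range(5000):
    slow = ttorch.as_tensor({'key': obs})  # ~2.5 s for me with the print options set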
Thank you
