44import traceback
55from http import HTTPStatus
66from http .server import BaseHTTPRequestHandler , HTTPServer
7- from typing import AnyStr
7+ from typing import Any , AnyStr , Dict , List
88
99import cloudpickle
10+ import torch
1011from tensordict import TensorDict
1112
1213from areal .api .controller_api import DistributedBatch
1314from areal .api .engine_api import InferenceEngine
1415from areal .controller .batch import DistributedBatchMemory
1516from areal .utils import logging
16- from areal .utils .data import (
17- tensor_container_to ,
18- )
1917
2018logger = logging .getLogger ("RPCServer" )
2119
2220
21+ def tensor_container_to_safe (
22+ d : Dict [str , Any ] | torch .Tensor | List [torch .Tensor ], * args , ** kwargs
23+ ):
24+ """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary.
25+ Support nested dictionaries.
26+ """
27+ new_dict = {}
28+ if torch .is_tensor (d ):
29+ return d .to (* args , ** kwargs )
30+ elif isinstance (d , list ):
31+ return [tensor_container_to_safe (v , * args , ** kwargs ) for v in d ]
32+ elif isinstance (d , dict ):
33+ for key , value in d .items ():
34+ if isinstance (value , dict ) or isinstance (value , list ):
35+ new_dict [key ] = tensor_container_to_safe (value , * args , ** kwargs )
36+ elif torch .is_tensor (value ):
37+ new_dict [key ] = value .to (* args , ** kwargs )
38+ else :
39+ new_dict [key ] = value
40+ return new_dict
41+ else :
42+ return d
43+
44+
2345def process_input_to_distributed_batch (to_device , * args , ** kwargs ):
2446 for i in range (len (args )):
2547 if isinstance (args [i ], DistributedBatch ):
@@ -31,14 +53,14 @@ def process_input_to_distributed_batch(to_device, *args, **kwargs):
3153 if isinstance (kwargs [k ], DistributedBatch ):
3254 kwargs [k ] = kwargs [k ].get_data ()
3355
34- args = tuple (tensor_container_to (list (args ), to_device ))
35- kwargs = tensor_container_to (kwargs , to_device )
56+ args = tuple (tensor_container_to_safe (list (args ), to_device ))
57+ kwargs = tensor_container_to_safe (kwargs , to_device )
3658
3759 return args , kwargs
3860
3961
4062def process_output_to_distributed_batch (result ):
41- result = tensor_container_to (result , "cpu" )
63+ result = tensor_container_to_safe (result , "cpu" )
4264 if isinstance (result , dict ):
4365 return DistributedBatchMemory .from_dict (result )
4466 elif isinstance (result , TensorDict ):
0 commit comments