Skip to content

Commit 9ae69d2

Browse files
committed
update to fix comment
1 parent 21d5385 commit 9ae69d2

File tree

4 files changed

+32
-14
lines changed

4 files changed

+32
-14
lines changed

areal/engine/fsdp_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def create_process_group(self, parallel_strategy: ParallelStrategy | None = None
121121
self.dp_head = int(self.world_mesh["sp_tp"].mesh[0].item())
122122
self.dp_rank = dist.get_rank(self.dp_group)
123123

124+
self.world_size = int(os.environ["WORLD_SIZE"])
125+
124126
self.logger.info(f"Data parallel head {self.dp_head} and rank {self.dp_rank}")
125127

126128
def initialize(
@@ -137,8 +139,6 @@ def initialize(
137139
"torch", "2.4.0"
138140
), f"areal only supports FSDP2, which requires torch>=2.4.0"
139141

140-
self.world_size = int(os.environ["WORLD_SIZE"])
141-
142142
# Create device model
143143
self.create_device_model()
144144

areal/scheduler/rpc/rpc_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def create_engine(
2828
self,
2929
worker_id: str,
3030
engine_obj: Union[InferenceEngine, TrainEngine],
31-
# init_config: Union[InferenceEngineConfig, TrainEngineConfig],
3231
*args,
3332
**kwargs,
3433
) -> None:

areal/scheduler/rpc/rpc_server.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,44 @@
44
import traceback
55
from http import HTTPStatus
66
from http.server import BaseHTTPRequestHandler, HTTPServer
7-
from typing import AnyStr
7+
from typing import Any, AnyStr, Dict, List
88

99
import cloudpickle
10+
import torch
1011
from tensordict import TensorDict
1112

1213
from areal.api.controller_api import DistributedBatch
1314
from areal.api.engine_api import InferenceEngine
1415
from areal.controller.batch import DistributedBatchMemory
1516
from areal.utils import logging
16-
from areal.utils.data import (
17-
tensor_container_to,
18-
)
1917

2018
logger = logging.getLogger("RPCServer")
2119

2220

21+
def tensor_container_to_safe(
22+
d: Dict[str, Any] | torch.Tensor | List[torch.Tensor], *args, **kwargs
23+
):
24+
"""Apply `t.to(*args, **kwargs)` to all tensors in the dictionary.
25+
Support nested dictionaries.
26+
"""
27+
new_dict = {}
28+
if torch.is_tensor(d):
29+
return d.to(*args, **kwargs)
30+
elif isinstance(d, list):
31+
return [tensor_container_to_safe(v, *args, **kwargs) for v in d]
32+
elif isinstance(d, dict):
33+
for key, value in d.items():
34+
if isinstance(value, dict) or isinstance(value, list):
35+
new_dict[key] = tensor_container_to_safe(value, *args, **kwargs)
36+
elif torch.is_tensor(value):
37+
new_dict[key] = value.to(*args, **kwargs)
38+
else:
39+
new_dict[key] = value
40+
return new_dict
41+
else:
42+
return d
43+
44+
2345
def process_input_to_distributed_batch(to_device, *args, **kwargs):
2446
for i in range(len(args)):
2547
if isinstance(args[i], DistributedBatch):
@@ -31,14 +53,14 @@ def process_input_to_distributed_batch(to_device, *args, **kwargs):
3153
if isinstance(kwargs[k], DistributedBatch):
3254
kwargs[k] = kwargs[k].get_data()
3355

34-
args = tuple(tensor_container_to(list(args), to_device))
35-
kwargs = tensor_container_to(kwargs, to_device)
56+
args = tuple(tensor_container_to_safe(list(args), to_device))
57+
kwargs = tensor_container_to_safe(kwargs, to_device)
3658

3759
return args, kwargs
3860

3961

4062
def process_output_to_distributed_batch(result):
41-
result = tensor_container_to(result, "cpu")
63+
result = tensor_container_to_safe(result, "cpu")
4264
if isinstance(result, dict):
4365
return DistributedBatchMemory.from_dict(result)
4466
elif isinstance(result, TensorDict):

areal/utils/data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,7 @@ def tensor_container_to(
351351
new_dict[key] = value
352352
return new_dict
353353
else:
354-
logger.warning(
355-
f"Unsupported type in tensor_container_to: {type(d)}, returning original."
356-
)
357-
return d
354+
raise ValueError(f"Unsupported type: {type(d)}")
358355

359356

360357
@dataclass

0 commit comments

Comments
 (0)