Commit 59f60b0

fix for cr
1 parent 616f6de commit 59f60b0

File tree

2 files changed, +81 -44 lines changed

areal/scheduler/rpc/rpc_server.py

Lines changed: 61 additions & 36 deletions
@@ -1,7 +1,9 @@
 import argparse
 import importlib
 import traceback
+from collections.abc import Callable
 from concurrent.futures import Future
+from typing import Any
 
 from flask import Flask, jsonify, request
 
@@ -20,6 +22,53 @@
 # Global engine instance - must be TrainEngine or InferenceEngine
 _engine: TrainEngine | InferenceEngine | None = None
 
+
+def _handle_submit(
+    engine: InferenceEngine,
+    args: list[Any],
+    kwargs: dict[str, Any],
+) -> tuple[list[Any], dict[str, Any]]:
+    workflow_path = kwargs["workflow_path"]
+    workflow_kwargs = kwargs["workflow_kwargs"]
+    episode_data = kwargs["data"]
+    should_accept_path = kwargs["should_accept_path"]
+
+    episode_data = deserialize_value(episode_data)
+
+    module_path, class_name = workflow_path.rsplit(".", 1)
+    module = importlib.import_module(module_path)
+    workflow_class = getattr(module, class_name)
+    logger.info(f"Imported workflow class: {workflow_path}")
+
+    workflow_kwargs = deserialize_value(workflow_kwargs)
+    workflow = workflow_class(**workflow_kwargs)
+    logger.info(f"Workflow '{workflow_path}' instantiated successfully")
+
+    should_accept = None
+    if should_accept_path is not None:
+        module_path, fn_name = should_accept_path.rsplit(".", 1)
+        module = importlib.import_module(module_path)
+        should_accept = getattr(module, fn_name)
+        logger.info(f"Imported filtering function: {should_accept_path}")
+
+    new_args: list[Any] = []
+    new_kwargs: dict[str, Any] = dict(
+        data=episode_data,
+        workflow=workflow,
+        should_accept=should_accept,
+    )
+    return new_args, new_kwargs
+
+
+_METHOD_HANDLERS: dict[
+    str,
+    Callable[
+        [InferenceEngine, list[Any], dict[str, Any]], tuple[list[Any], dict[str, Any]]
+    ],
+] = {
+    "submit": _handle_submit,
+}
+
 # Create Flask app
 app = Flask(__name__)
 

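The commit's main refactor: the inline `submit` special case moves into `_handle_submit`, and `_METHOD_HANDLERS` maps method names to pre-processing callables. Extending the server to pre-process another inference-engine method should now only require registering one more entry in this module. A minimal sketch of what that could look like (the `rollout` method name and its kwargs are hypothetical, not part of this commit):

    def _handle_rollout(
        engine: InferenceEngine,
        args: list[Any],
        kwargs: dict[str, Any],
    ) -> tuple[list[Any], dict[str, Any]]:
        # Deserialize tensor-bearing payloads before they reach the engine.
        data = deserialize_value(kwargs["data"])
        return [], dict(data=data)

    _METHOD_HANDLERS["rollout"] = _handle_rollout
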
@@ -207,42 +256,12 @@ def call_engine_method():
             500,
         )
 
-    # Special case for `submit` on inference engines
+    # Special handling for some methods on inference engines
     try:
-        if method_name == "submit" and isinstance(_engine, InferenceEngine):
-            workflow_path = kwargs["workflow_path"]
-            workflow_kwargs = kwargs["workflow_kwargs"]
-            episode_data = kwargs["data"]
-            should_accept_path = kwargs["should_accept_path"]
-
-            # Deserialize episode_data (may contain tensors)
-            episode_data = deserialize_value(episode_data)
-
-            # Dynamic import workflow
-            module_path, class_name = workflow_path.rsplit(".", 1)
-            module = importlib.import_module(module_path)
-            workflow_class = getattr(module, class_name)
-            logger.info(f"Imported workflow class: {workflow_path}")
-
-            # Instantiate workflow
-            workflow_kwargs = deserialize_value(workflow_kwargs)
-            workflow = workflow_class(**workflow_kwargs)
-            logger.info(f"Workflow '{workflow_path}' instantiated successfully")
-
-            should_accept = None
-            if should_accept_path is not None:
-                # Dynamic import filtering function
-                module_path, fn_name = should_accept_path.rsplit(".", 1)
-                module = importlib.import_module(module_path)
-                should_accept = getattr(module, fn_name)
-                logger.info(f"Imported filtering function: {should_accept_path}")
-
-            args = []
-            kwargs = dict(
-                data=episode_data,
-                workflow=workflow,
-                should_accept=should_accept,
-            )
+        if isinstance(_engine, InferenceEngine):
+            handler = _METHOD_HANDLERS.get(method_name)
+            if handler is not None:
+                args, kwargs = handler(_engine, args, kwargs)
     except Exception as e:
         logger.error(
             f"Workflow data conversion failed: {e}\n{traceback.format_exc()}"
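
Note how the caller contract stays the same: `submit` requests carry dotted import paths (`workflow_path`, `should_accept_path`) rather than live objects, since classes and functions cannot cross the RPC boundary. The exact request schema and route are not shown in this diff; the client-side sketch below guesses a JSON shape from the kwargs the handler reads (the URL, route name, and workflow class are placeholders):

    import requests

    from areal.scheduler.rpc.serialization import serialize_value

    episode_data = {"prompt": "hello"}  # illustrative payload; may contain tensors

    payload = {
        "method": "submit",
        "args": [],
        "kwargs": {
            "workflow_path": "my_pkg.workflows.MyWorkflow",  # dotted import path
            "workflow_kwargs": serialize_value({"temperature": 0.7}),
            "data": serialize_value(episode_data),
            "should_accept_path": None,  # or e.g. "my_pkg.filters.accept"
        },
    }
    resp = requests.post("http://localhost:8000/call_engine_method", json=payload)
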
@@ -261,7 +280,9 @@ def call_engine_method():
 
     # HACK: handle update weights future
     if isinstance(result, Future):
+        logger.info("Waiting for update weights future")
         result = result.result()
+        logger.info("Update weights future done")
 
     # Serialize result (convert tensors to SerializedTensor dicts)
     serialized_result = serialize_value(result)
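
The two added log lines bracket a blocking wait: a `concurrent.futures.Future` cannot be JSON-serialized, so the handler must resolve it before `serialize_value` runs. The pattern in isolation, as a self-contained sketch:

    from concurrent.futures import Future, ThreadPoolExecutor

    def resolve_if_future(result):
        # Block until an async result materializes; pass plain values through.
        if isinstance(result, Future):
            result = result.result()
        return result

    with ThreadPoolExecutor() as pool:
        fut = pool.submit(lambda: 42)
        assert resolve_if_future(fut) == 42
        assert resolve_if_future("plain") == "plain"
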
@@ -296,7 +317,11 @@ def export_stats():
         return jsonify({"error": "Engine not initialized"}), 503
 
     # TrainEngine: reduce stats across data_parallel_group
-    assert isinstance(_engine, TrainEngine)
+    if not isinstance(_engine, TrainEngine):
+        return (
+            jsonify({"error": "/export_stats is only available for TrainEngine"}),
+            400,
+        )
     result = stats_tracker.export(reduce_group=_engine.data_parallel_group)
     return jsonify({"status": "success", "result": result})
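
Replacing the assert with an explicit type check turns a server-side crash (an AssertionError, surfaced as a 500) into a structured 400 response when the endpoint is hit on an InferenceEngine-backed server. A sketch of probing the new behavior, assuming the server listens on localhost:8000 (host, port, and HTTP method are placeholders; only the route name is taken from the error string above):

    import requests

    resp = requests.get("http://localhost:8000/export_stats")
    if resp.status_code == 400:
        # "/export_stats is only available for TrainEngine"
        print(resp.json()["error"])
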

areal/scheduler/rpc/serialization.py

Lines changed: 20 additions & 8 deletions
@@ -25,10 +25,14 @@
 import torch
 from pydantic import BaseModel, Field
 
+from areal.utils import logging
+
 TOKENIZER_ARCHIVE_INLINE_THRESHOLD = 512 * 1024
 TOKENIZER_ZSTD_THRESHOLD = 20 * 1024 * 1024
 TokenizerCompression = Literal["zip", "zstd"]
 
+logger = logging.getLogger("SyncRPCServer")
+
 
 class SerializedTensor(BaseModel):
     """Pydantic model for serialized tensor with metadata.
@@ -96,9 +100,6 @@ def to_tensor(self) -> torch.Tensor:
         dtype_str = self.dtype.replace("torch.", "")
         dtype = getattr(torch, dtype_str)
 
-        # Reconstruct tensor from bytes
-        import numpy as np
-
         np_array = np.frombuffer(buffer, dtype=self._torch_dtype_to_numpy(dtype))
         # Copy the array to make it writable before converting to tensor
         np_array = np_array.copy()
@@ -465,33 +466,44 @@ def deserialize_value(value: Any) -> Any:
             }
             # Reconstruct the dataclass instance
             return dataclass_type(**deserialized_data)
-        except Exception:
+        except Exception as e:
             # If parsing fails, treat as regular dict
+            logger.warning(
+                f"Failed to deserialize dataclass, treating as regular dict: {e}"
+            )
             pass
 
     # Check for SerializedTokenizer marker
     if value.get("type") == "tokenizer":
         try:
             serialized_tokenizer = SerializedTokenizer.model_validate(value)
             return serialized_tokenizer.to_tokenizer()
-        except Exception:
+        except Exception as e:
+            logger.warning(
+                f"Failed to deserialize tokenizer, treating as regular dict: {e}"
+            )
             pass
 
     # Check for SerializedNDArray marker
     if value.get("type") == "ndarray":
         try:
             serialized_array = SerializedNDArray.model_validate(value)
             return serialized_array.to_array()
-        except Exception:
+        except Exception as e:
+            logger.warning(
+                f"Failed to deserialize ndarray, treating as regular dict: {e}"
+            )
             pass
 
     # Check for SerializedTensor marker
     if value.get("type") == "tensor":
         try:
             serialized_tensor = SerializedTensor.model_validate(value)
             return serialized_tensor.to_tensor()
-        except Exception:
-            # If parsing fails, treat as regular dict
+        except Exception as e:
+            logger.warning(
+                f"Failed to deserialize tensor, treating as regular dict: {e}"
+            )
             pass
 
     # Regular dict - recursively deserialize values