Skip to content

Commit 02340e1

Browse files
add support for async inference (ray-project#54824)
This PR provides basic support for asynchronous inference in Ray Serve. The RFC can be found at: ray-project#54652. The PR doesn't contain all the implementation pieces, because having all the code changes in a single PR would make it very difficult to review. The missing pieces are: (1) implementation of the failed and unprocessed task queues for the Celery task processor; (2) more detailed and thorough tests for the same. These missing pieces will be taken care of in subsequent PRs. --------- Signed-off-by: harshit <[email protected]>
1 parent 4dd7321 commit 02340e1

22 files changed

+2088
-2
lines changed

doc/source/serve/api/index.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ Content-Type: application/json
384384
schema.ServeApplicationSchema
385385
schema.DeploymentSchema
386386
schema.RayActorOptionsSchema
387+
schema.CeleryAdapterConfig
388+
schema.TaskProcessorConfig
389+
schema.TaskResult
387390
```
388391

389392
(serve-rest-api-response-schema)=

python/ray/serve/schema.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import Counter
33
from dataclasses import dataclass, field
44
from enum import Enum
5-
from typing import Any, Dict, List, Optional, Set, Union
5+
from typing import Any, Callable, Dict, List, Optional, Set, Union
66
from zlib import crc32
77

88
from ray._common.pydantic_compat import (
@@ -1202,3 +1202,64 @@ def _get_user_facing_json_serializable_dict(
12021202
)
12031203

12041204
return values
1205+
1206+
1207+
@PublicAPI(stability="alpha")
class CeleryAdapterConfig(BaseModel):
    """Settings for the Celery adapter.

    Use this model to configure the Celery task processor for your Serve
    application. ``broker_url`` and ``backend_url`` are required; the
    remaining fields fall back to the defaults declared below.
    """

    # Required connection URLs.
    broker_url: str = Field(
        ...,
        description="The URL of the broker to use for Celery.",
    )
    backend_url: str = Field(
        ...,
        description="The URL of the backend to use for Celery.",
    )

    # Optional settings.
    broker_transport_options: Optional[Dict[str, Any]] = Field(
        description="The broker transport options to use for Celery.",
        default=None,
    )
    worker_concurrency: Optional[int] = Field(
        description="The number of concurrent worker threads for the task processor.",
        default=10,
    )
1224+
1225+
1226+
@PublicAPI(stability="alpha")
class TaskProcessorConfig(BaseModel):
    """Settings for a task processor.

    Use this model to configure the task processor for your Serve
    application. ``queue_name`` and ``adapter_config`` are required; the
    remaining fields fall back to the defaults declared below.
    """

    queue_name: str = Field(
        ...,
        description="The name of the queue to use for task processing.",
    )
    # Either a dotted import path to the adapter or the callable itself.
    adapter: Union[str, Callable] = Field(
        description="The adapter to use for task processing. By default, Celery is used.",
        default="ray.serve.task_processor.CeleryTaskProcessorAdapter",
    )
    # Typed as Any: the model itself places no constraint on the adapter config.
    adapter_config: Any = Field(
        ...,
        description="The adapter config.",
    )
    max_retries: Optional[int] = Field(
        description="The maximum number of times to retry a task before marking it as failed.",
        default=3,
    )

    # Optional dead-letter queue names; both default to None.
    failed_task_queue_name: Optional[str] = Field(
        description="The name of the failed task queue. This is used to move failed tasks to a dead-letter queue after max retries.",
        default=None,
    )
    unprocessable_task_queue_name: Optional[str] = Field(
        description="The name of the unprocessable task queue. This is used to move unprocessable tasks(like tasks with serialization issue, or missing handler) to a dead-letter queue.",
        default=None,
    )
1252+
1253+
1254+
@PublicAPI(stability="alpha")
class TaskResult(BaseModel):
    """Model describing a task and its result.

    ``id``, ``status`` and ``result`` are required; ``created_at`` is
    optional and defaults to ``None``.
    """

    id: str = Field(
        ...,
        description="The ID of the task.",
    )
    status: str = Field(
        ...,
        description="The status of the task.",
    )
    created_at: Optional[float] = Field(
        description="The timestamp of the task creation.",
        default=None,
    )
    result: Any = Field(
        ...,
        description="The result of the task.",
    )

python/ray/serve/task_consumer.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import inspect
2+
import logging
3+
from functools import wraps
4+
from typing import Callable, Optional
5+
6+
from ray._common.utils import import_attr
7+
from ray.serve._private.constants import SERVE_LOGGER_NAME
8+
from ray.serve.schema import TaskProcessorConfig
9+
from ray.serve.task_processor import TaskProcessorAdapter
10+
from ray.util.annotations import PublicAPI
11+
12+
logger = logging.getLogger(SERVE_LOGGER_NAME)
13+
14+
15+
@PublicAPI(stability="alpha")
def instantiate_adapter_from_config(
    task_processor_config: TaskProcessorConfig,
) -> TaskProcessorAdapter:
    """
    Create a TaskProcessorAdapter instance from the provided configuration and
    call .initialize() on it. This function supports two ways to specify an adapter:

    1. String path: A fully qualified module path to an adapter class
        Example: "ray.serve.task_processor.CeleryTaskProcessorAdapter"

    2. Class reference: A direct reference to an adapter class
        Example: CeleryTaskProcessorAdapter

    Args:
        task_processor_config: Configuration object containing adapter specification.

    Returns:
        An initialized TaskProcessorAdapter instance ready for use.

    Raises:
        TypeError: If the adapter is neither a string nor a callable, or if the
            constructed object is not a TaskProcessorAdapter instance.
        RuntimeError: If constructing the adapter or calling its
            ``initialize()`` fails. The original exception is chained as
            ``__cause__``.

        Errors raised by ``import_attr`` while resolving a string path are not
        caught here and propagate unchanged.

    Example:
        .. code-block:: python

            config = TaskProcessorConfig(
                adapter="my.module.CustomAdapter",
                adapter_config={"param": "value"},
                queue_name="my_queue"
            )
            adapter = instantiate_adapter_from_config(config)
    """

    adapter = task_processor_config.adapter

    # Handle string-based adapter specification (module path)
    if isinstance(adapter, str):
        adapter_class = import_attr(adapter)

    elif callable(adapter):
        adapter_class = adapter

    else:
        raise TypeError(
            f"Adapter must be either a string path or a callable class, got {type(adapter).__name__}: {adapter}"
        )

    # Not every callable has ``__name__`` (e.g. ``functools.partial`` objects).
    # Resolve a printable name once so the error paths below cannot themselves
    # fail with AttributeError and hide the real problem.
    adapter_name = getattr(adapter_class, "__name__", repr(adapter_class))

    try:
        adapter_instance = adapter_class(config=task_processor_config)
    except Exception as e:
        raise RuntimeError(f"Failed to instantiate {adapter_name}: {e}") from e

    if not isinstance(adapter_instance, TaskProcessorAdapter):
        raise TypeError(
            f"{adapter_name} must inherit from TaskProcessorAdapter, got {type(adapter_instance).__name__}"
        )

    try:
        adapter_instance.initialize(config=task_processor_config)
    except Exception as e:
        raise RuntimeError(f"Failed to initialize {adapter_name}: {e}") from e

    return adapter_instance
79+
80+
81+
@PublicAPI(stability="alpha")
def task_consumer(*, task_processor_config: TaskProcessorConfig):
    """
    Decorator to mark a class as a TaskConsumer.

    On construction, the wrapper class runs the target class's ``__init__``,
    builds an adapter with ``instantiate_adapter_from_config``, registers every
    function on the target class flagged with ``_is_task_handler``, and then
    starts the consumer. On finalization it stops the consumer, shuts the
    adapter down, and calls the target class's ``__del__`` if one exists.

    Args:
        task_processor_config: Configuration for the task processor (required)

    Note:
        This decorator must be used with parentheses:
        @task_consumer(task_processor_config=config)

    Returns:
        A wrapper class that inherits from the target class and implements the task consumer functionality.

    Example:
        .. code-block:: python

            from ray import serve
            from ray.serve.task_consumer import task_consumer, task_handler

            @serve.deployment
            @task_consumer(task_processor_config=config)
            class MyTaskConsumer:

                @task_handler(name="my_task")
                def my_task(self, *args, **kwargs):
                    pass

    """

    def decorator(target_cls):
        class TaskConsumerWrapper(target_cls):
            _adapter: TaskProcessorAdapter

            def __init__(self, *args, **kwargs):
                target_cls.__init__(self, *args, **kwargs)

                self._adapter = instantiate_adapter_from_config(task_processor_config)

                for name, method in inspect.getmembers(
                    target_cls, predicate=inspect.isfunction
                ):
                    if getattr(method, "_is_task_handler", False):
                        task_name = getattr(method, "_task_name", name)

                        # Look the attribute up on the instance so the handler
                        # registered with the adapter is bound to ``self``.
                        bound_method = getattr(self, name)

                        self._adapter.register_task_handle(bound_method, task_name)

                try:
                    self._adapter.start_consumer()
                    logger.info("task consumer started successfully")
                except Exception as e:
                    logger.error(f"Failed to start task consumer: {e}")
                    raise

            def __del__(self):
                # ``__del__`` also runs for instances whose ``__init__`` raised
                # before ``_adapter`` was assigned; in that case there is
                # nothing to stop or shut down.
                adapter = getattr(self, "_adapter", None)

                try:
                    if adapter is not None:
                        try:
                            adapter.stop_consumer()
                        finally:
                            # Always attempt shutdown, even if stopping failed.
                            adapter.shutdown()
                finally:
                    # The wrapped class's own finalizer must run regardless of
                    # adapter cleanup errors.
                    if hasattr(target_cls, "__del__"):
                        target_cls.__del__(self)

        return TaskConsumerWrapper

    return decorator
149+
150+
151+
@PublicAPI(stability="alpha")
def task_handler(
    _func: Optional[Callable] = None, *, name: Optional[str] = None
) -> Callable:
    """
    Mark a method as a task handler.

    The decorated function is wrapped, and the wrapper is tagged with
    ``_is_task_handler = True`` and ``_task_name`` (the given ``name``, or the
    function's ``__name__`` when no name is supplied). Works both bare
    (``@task_handler``) and with arguments (``@task_handler(name="...")``).

    Arguments:
        _func: The function to decorate.
        name: The name of the task.

    Returns:
        A wrapper function that is marked as a task handler, or, when called
        with arguments only, a decorator producing such a wrapper.

    Raises:
        ValueError: If ``name`` is given but is not a non-empty string.
        NotImplementedError: If the decorated function is a coroutine function.

    Example:
        .. code-block:: python

            from ray import serve
            from ray.serve.task_consumer import task_consumer, task_handler

            @serve.deployment
            @task_consumer(task_processor_config=config)
            class MyTaskConsumer:

                @task_handler(name="my_task")
                def my_task(self, *args, **kwargs):
                    pass

    """

    # Reject an explicit name that is not a string or is blank.
    if name is not None:
        name_is_usable = isinstance(name, str) and bool(name.strip())
        if not name_is_usable:
            raise ValueError(f"Task name must be a non-empty string, got {name}")

    def decorator(f):
        # async functions are not supported yet in celery `threads` worker pool
        if inspect.iscoroutinefunction(f):
            raise NotImplementedError("Async task handlers are not supported yet")

        @wraps(f)
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)

        # Markers read back through getattr when handlers are registered.
        wrapper._task_name = name or f.__name__  # type: ignore
        wrapper._is_task_handler = True  # type: ignore
        return wrapper

    # Bare use (@task_handler) passes the function directly; parameterized use
    # (@task_handler(name="...")) needs the decorator itself returned.
    return decorator if _func is None else decorator(_func)

0 commit comments

Comments
 (0)