
Commit 117daf0

fix: add abort for llm2
1 parent 306eb71


3 files changed, +166 -32 lines changed


api/llm-interface.json

Lines changed: 13 additions & 0 deletions

```diff
@@ -5,6 +5,9 @@
       "name": "chat_completion",
       "property": {
         "properties": {
+          "request_id": {
+            "type": "string"
+          },
           "model": {
             "type": "string"
           },
@@ -60,6 +63,16 @@
           }
         }
       }
+    },
+    {
+      "name": "abort",
+      "property": {
+        "properties": {
+          "request_id": {
+            "type": "string"
+          }
+        }
+      }
     }
   ],
   "data_out": [
```

interface/ten_ai_base/llm2.py

Lines changed: 146 additions & 32 deletions

```diff
@@ -4,10 +4,12 @@
 # See the LICENSE file for more information.
 #
 from abc import ABC, abstractmethod
+import asyncio
+import json
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, Optional

-from .struct import LLMRequest, LLMResponse
+from .struct import LLMRequest, LLMRequestAbort, LLMResponse
 from ten_runtime import (
     AsyncExtension,
 )
@@ -31,60 +33,172 @@ class AsyncLLM2BaseExtension(AsyncExtension, ABC):
     def __init__(self, name: str):
         super().__init__(name)
         self.ten_env: AsyncTenEnv = None
+        self._inflight: Dict[str, "AsyncLLM2BaseExtension._TaskCtx"] = {}
+        self._lock = asyncio.Lock()
+

     async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
         await super().on_init(async_ten_env)
+        self.ten_env = async_ten_env

     async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
         await super().on_start(async_ten_env)
-        self.ten_env = async_ten_env

     async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
+        await self._cancel_all()
         await super().on_stop(async_ten_env)

     async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None:
+        await self._cancel_all()
         await super().on_deinit(async_ten_env)

     async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
-        """
-        handle default commands
-        return True if the command is handled, False otherwise
-        """
         cmd_name = cmd.get_name()
-        async_ten_env.log_debug(f"on_cmd name22 {cmd_name}")
+        async_ten_env.log_debug(f"[LLM2Base] on_cmd: {cmd_name}")
         try:
             if cmd_name == "chat_completion":
                 payload, err = cmd.get_property_to_json(None)
                 if err:
-                    raise RuntimeError(f"Failed to get payload: {err}")
-                args = LLMRequest.model_validate_json(
-                    payload
-                )
-                response = self.on_call_chat_completion(
-                    async_ten_env, args
-                )
+                    raise RuntimeError(f"Failed to get payload: {err}")
+
+                req = LLMRequest.model_validate_json(payload)
+                rid = req.request_id
+                if not rid:
+                    raise RuntimeError("LLMRequest.request_id is required")
+
+                # Reject duplicates instead of replacing
+                async with self._lock:
+                    existing = self._inflight.get(rid)
+                    if existing and not existing.task.done():
+                        async_ten_env.log_info(
+                            f"[LLM2Base] Duplicate request_id rejected: {rid}"
+                        )
+                        cr = CmdResult.create(StatusCode.ERROR, cmd)
+                        cr.set_property_from_json(
+                            None,
+                            json.dumps({
+                                "error": "request_id_already_running",
+                                "message": "A chat_completion with this request_id is already in progress.",
+                                "request_id": rid,
+                            }),
+                        )
+                        await async_ten_env.return_result(cr)
+                        return

-                async for llm_choice in response:
-                    # If the response is a final output, we can return it directly
-                    cmd_result = CmdResult.create(StatusCode.OK, cmd)
-                    cmd_result.set_property_from_json(
-                        None, llm_choice.model_dump_json()
+                    # Start streaming task
+                    await self._start_locked(async_ten_env, cmd, req)
+
+                # Ack creation (streaming results will arrive from the task)
+                # await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))
+
+            elif cmd_name == "abort":
+                payload, err = cmd.get_property_to_json(None)
+                if err:
+                    raise RuntimeError(f"Failed to get payload: {err}")
+
+                abort = LLMRequestAbort.model_validate_json(payload)
+                req_id: Optional[str] = getattr(abort, "request_id", None)
+
+                if req_id:
+                    cancelled = await self._cancel_one(req_id)
+                    async_ten_env.log_info(
+                        f"[LLM2Base] abort: request_id={req_id}, cancelled={cancelled}"
                     )
-                    cmd_result.set_final(False)
-                    await async_ten_env.return_result(cmd_result)
+                else:
+                    await self._cancel_all()
+                    async_ten_env.log_info("[LLM2Base] abort: all requests cancelled")
+
+                await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))

-                cmd_result = CmdResult.create(StatusCode.OK, cmd)
-                cmd_result.set_final(True)
-                await async_ten_env.return_result(cmd_result)
             else:
-                await async_ten_env.return_result(
-                    CmdResult.create(StatusCode.OK, cmd)
+                await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))
+
+        except Exception:
+            async_ten_env.log_error(f"[LLM2Base] on_cmd error:\n{traceback.format_exc()}")
+            await async_ten_env.return_result(CmdResult.create(StatusCode.ERROR, cmd))
+
+    # ---------------------------
+    # Concurrency & task plumbing
+    # ---------------------------
+
+    class _TaskCtx:
+        __slots__ = ("task", "cmd", "request_id")
+        def __init__(self, task: asyncio.Task, cmd: Cmd, request_id: str):
+            self.task = task
+            self.cmd = cmd
+            self.request_id = request_id
+
+    async def _start_locked(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None:
+        """Call with self._lock held. Starts a task and registers it in _inflight."""
+        rid = req.request_id
+        task = asyncio.create_task(self._run_stream(ten_env, cmd, req), name=f"llm2:{rid}")
+        self._inflight[rid] = self._TaskCtx(task=task, cmd=cmd, request_id=rid)
+        task.add_done_callback(lambda t, rid=rid: asyncio.create_task(self._cleanup_after(rid)))
+
+    async def _run_stream(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None:
+        rid = req.request_id
+        try:
+            gen = self.on_call_chat_completion(ten_env, req)
+            async for chunk in gen:
+                try:
+                    cr = CmdResult.create(StatusCode.OK, cmd)
+                    cr.set_property_from_json(None, chunk.model_dump_json())
+                    cr.set_final(False)
+                    await ten_env.return_result(cr)
+                except Exception:
+                    ten_env.log_error(
+                        f"[LLM2Base] return_result streaming error (rid={rid}):\n{traceback.format_exc()}"
+                    )
+
+            final = CmdResult.create(StatusCode.OK, cmd)
+            final.set_final(True)
+            await ten_env.return_result(final)
+
+        except asyncio.CancelledError:
+            ten_env.log_info(f"[LLM2Base] stream cancelled (rid={rid})")
+            try:
+                final = CmdResult.create(StatusCode.OK, cmd)
+                # Optionally attach abort metadata:
+                # final.set_property_from_json(None, json.dumps({"aborted": True, "request_id": rid}))
+                final.set_final(True)
+                await ten_env.return_result(final)
+            except Exception:
+                ten_env.log_error(
+                    f"[LLM2Base] error returning final for cancelled stream (rid={rid}):\n{traceback.format_exc()}"
+                )
+            raise
+        except Exception:
+            ten_env.log_error(f"[LLM2Base] stream error (rid={rid}):\n{traceback.format_exc()}")
+            try:
+                err_final = CmdResult.create(StatusCode.ERROR, cmd)
+                err_final.set_final(True)
+                await ten_env.return_result(err_final)
+            except Exception:
+                ten_env.log_error(
+                    f"[LLM2Base] error returning ERROR final (rid={rid}):\n{traceback.format_exc()}"
                 )
-        except Exception as e:
-            async_ten_env.log_error(f"on_cmd error: {traceback.format_exc()}")
-            await async_ten_env.return_result(
-                CmdResult.create(StatusCode.ERROR, cmd)
-            )
+
+    async def _cleanup_after(self, request_id: str) -> None:
+        async with self._lock:
+            ctx = self._inflight.get(request_id)
+            if ctx and ctx.task.done():
+                self._inflight.pop(request_id, None)
+
+    async def _cancel_one(self, request_id: str) -> bool:
+        async with self._lock:
+            ctx = self._inflight.get(request_id)
+            if not ctx:
+                return False
+            if not ctx.task.done():
+                ctx.task.cancel()
+                return True
+            return False
+
+    async def _cancel_all(self) -> None:
+        async with self._lock:
+            for ctx in list(self._inflight.values()):
+                if not ctx.task.done():
+                    ctx.task.cancel()

     @abstractmethod
     def on_call_chat_completion(
```
interface/ten_ai_base/struct.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -124,12 +124,19 @@ class LLMRequest(BaseModel):
     Model for LLM input data.
     This model is used to define the structure of the input data for LLM operations.
     """
+    request_id: str
     model: str
     messages: list[LLMMessage]
     streaming: Optional[bool] = True
     tools: Optional[list[LLMToolMetadata]] = None
     parameters: Optional[dict[str, Any]] = None

+class LLMRequestAbort(BaseModel):
+    """
+    Model for LLM abort request.
+    This model is used to define the structure of the request to abort an ongoing LLM operation.
+    """
+    request_id: str

 class EventType(str, Enum):
     MESSAGE_CONTENT_DELTA = "message_content_delta"
```
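Both models round-trip through the pydantic v2 API the base class already uses (`model_validate_json` on the way in, `model_dump_json` on the way out). A standalone sketch, re-declaring `LLMRequestAbort` locally so it runs without the package:

```python
from pydantic import BaseModel


class LLMRequestAbort(BaseModel):
    """Local mirror of the model added above, for a runnable demo."""
    request_id: str


payload = LLMRequestAbort(request_id="req-42").model_dump_json()
assert payload == '{"request_id":"req-42"}'  # pydantic v2 emits compact JSON

restored = LLMRequestAbort.model_validate_json(payload)
assert restored.request_id == "req-42"
```

Since `request_id` is now a required field on `LLMRequest` as well, existing callers of `chat_completion` must start supplying it; validation fails otherwise.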
