 # See the LICENSE file for more information.
 #
 from abc import ABC, abstractmethod
+import asyncio
+import json
 import traceback
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Dict, Optional
 
-from .struct import LLMRequest, LLMResponse
+from .struct import LLMRequest, LLMRequestAbort, LLMResponse
 from ten_runtime import (
     AsyncExtension,
 )
@@ -31,60 +33,172 @@ class AsyncLLM2BaseExtension(AsyncExtension, ABC):
     def __init__(self, name: str):
         super().__init__(name)
         self.ten_env: AsyncTenEnv = None
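+        # In-flight chat_completion streams keyed by request_id; guarded by _lock.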
+        self._inflight: Dict[str, "AsyncLLM2BaseExtension._TaskCtx"] = {}
+        self._lock = asyncio.Lock()
+
 
     async def on_init(self, async_ten_env: AsyncTenEnv) -> None:
         await super().on_init(async_ten_env)
+        self.ten_env = async_ten_env
 
     async def on_start(self, async_ten_env: AsyncTenEnv) -> None:
         await super().on_start(async_ten_env)
-        self.ten_env = async_ten_env
 
     async def on_stop(self, async_ten_env: AsyncTenEnv) -> None:
+        await self._cancel_all()
         await super().on_stop(async_ten_env)
 
     async def on_deinit(self, async_ten_env: AsyncTenEnv) -> None:
+        await self._cancel_all()
         await super().on_deinit(async_ten_env)
 
     async def on_cmd(self, async_ten_env: AsyncTenEnv, cmd: Cmd) -> None:
-        """
-        handle default commands
-        return True if the command is handled, False otherwise
-        """
         cmd_name = cmd.get_name()
-        async_ten_env.log_debug(f"on_cmd name22 {cmd_name}")
+        async_ten_env.log_debug(f"[LLM2Base] on_cmd: {cmd_name}")
         try:
             if cmd_name == "chat_completion":
                 payload, err = cmd.get_property_to_json(None)
                 if err:
-                    raise RuntimeError(f"Failed to get payload: {err}")
-                args = LLMRequest.model_validate_json(
-                    payload
-                )
-                response = self.on_call_chat_completion(
-                    async_ten_env, args
-                )
+                    raise RuntimeError(f"Failed to get payload: {err}")
+
+                req = LLMRequest.model_validate_json(payload)
+                rid = req.request_id
+                if not rid:
+                    raise RuntimeError("LLMRequest.request_id is required")
+
+                # Reject duplicates instead of replacing
+                async with self._lock:
+                    existing = self._inflight.get(rid)
+                    if existing and not existing.task.done():
+                        async_ten_env.log_info(
+                            f"[LLM2Base] Duplicate request_id rejected: {rid}"
+                        )
+                        cr = CmdResult.create(StatusCode.ERROR, cmd)
+                        cr.set_property_from_json(
+                            None,
+                            json.dumps({
+                                "error": "request_id_already_running",
+                                "message": "A chat_completion with this request_id is already in progress.",
+                                "request_id": rid,
+                            }),
+                        )
+                        await async_ten_env.return_result(cr)
+                        return
 
-                async for llm_choice in response:
-                    # If the response is a final output, we can return it directly
-                    cmd_result = CmdResult.create(StatusCode.OK, cmd)
-                    cmd_result.set_property_from_json(
-                        None, llm_choice.model_dump_json()
+                    # Start streaming task
+                    await self._start_locked(async_ten_env, cmd, req)
+
+                # Ack creation (streaming results will arrive from the task)
+                # await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))
+
+            elif cmd_name == "abort":
+                payload, err = cmd.get_property_to_json(None)
+                if err:
+                    raise RuntimeError(f"Failed to get payload: {err}")
+
+                abort = LLMRequestAbort.model_validate_json(payload)
+                req_id: Optional[str] = getattr(abort, "request_id", None)
+
+                if req_id:
+                    cancelled = await self._cancel_one(req_id)
+                    async_ten_env.log_info(
+                        f"[LLM2Base] abort: request_id={req_id}, cancelled={cancelled}"
                     )
-                    cmd_result.set_final(False)
-                    await async_ten_env.return_result(cmd_result)
+                else:
+                    await self._cancel_all()
+                    async_ten_env.log_info("[LLM2Base] abort: all requests cancelled")
+
+                await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))
 
-                cmd_result = CmdResult.create(StatusCode.OK, cmd)
-                cmd_result.set_final(True)
-                await async_ten_env.return_result(cmd_result)
             else:
-                await async_ten_env.return_result(
-                    CmdResult.create(StatusCode.OK, cmd)
+                await async_ten_env.return_result(CmdResult.create(StatusCode.OK, cmd))
+
+        except Exception:
+            async_ten_env.log_error(f"[LLM2Base] on_cmd error:\n{traceback.format_exc()}")
+            await async_ten_env.return_result(CmdResult.create(StatusCode.ERROR, cmd))
+
+    # ---------------------------
+    # Concurrency & task plumbing
+    # ---------------------------
+
+    class _TaskCtx:
+        __slots__ = ("task", "cmd", "request_id")
+        def __init__(self, task: asyncio.Task, cmd: Cmd, request_id: str):
+            self.task = task
+            self.cmd = cmd
+            self.request_id = request_id
+
+    async def _start_locked(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None:
+        """Call with self._lock held. Starts a task and registers it in _inflight."""
+        rid = req.request_id
+        task = asyncio.create_task(self._run_stream(ten_env, cmd, req), name=f"llm2:{rid}")
+        self._inflight[rid] = self._TaskCtx(task=task, cmd=cmd, request_id=rid)
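+        # The done callback runs synchronously on the event loop, so async cleanup
+        # is scheduled as its own task.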
+        task.add_done_callback(lambda t, rid=rid: asyncio.create_task(self._cleanup_after(rid)))
+
+    async def _run_stream(self, ten_env: AsyncTenEnv, cmd: Cmd, req: LLMRequest) -> None:
+        rid = req.request_id
+        try:
+            gen = self.on_call_chat_completion(ten_env, req)
+            async for chunk in gen:
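+                # Each yielded chunk goes back to the caller as a non-final result;
+                # the final result below closes the command once the stream ends.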
+                try:
+                    cr = CmdResult.create(StatusCode.OK, cmd)
+                    cr.set_property_from_json(None, chunk.model_dump_json())
+                    cr.set_final(False)
+                    await ten_env.return_result(cr)
+                except Exception:
+                    ten_env.log_error(
+                        f"[LLM2Base] return_result streaming error (rid={rid}):\n{traceback.format_exc()}"
+                    )
+
+            final = CmdResult.create(StatusCode.OK, cmd)
+            final.set_final(True)
+            await ten_env.return_result(final)
+
+        except asyncio.CancelledError:
+            ten_env.log_info(f"[LLM2Base] stream cancelled (rid={rid})")
+            try:
+                final = CmdResult.create(StatusCode.OK, cmd)
+                # Optionally attach abort metadata:
+                # final.set_property_from_json(None, json.dumps({"aborted": True, "request_id": rid}))
+                final.set_final(True)
+                await ten_env.return_result(final)
+            except Exception:
+                ten_env.log_error(
+                    f"[LLM2Base] error returning final for cancelled stream (rid={rid}):\n{traceback.format_exc()}"
+                )
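+            # Re-raise so the task ends up cancelled rather than completed.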
+            raise
+        except Exception:
+            ten_env.log_error(f"[LLM2Base] stream error (rid={rid}):\n{traceback.format_exc()}")
+            try:
+                err_final = CmdResult.create(StatusCode.ERROR, cmd)
+                err_final.set_final(True)
+                await ten_env.return_result(err_final)
+            except Exception:
+                ten_env.log_error(
+                    f"[LLM2Base] error returning ERROR final (rid={rid}):\n{traceback.format_exc()}"
                 )
-        except Exception as e:
-            async_ten_env.log_error(f"on_cmd error: {traceback.format_exc()}")
-            await async_ten_env.return_result(
-                CmdResult.create(StatusCode.ERROR, cmd)
-            )
+
+    async def _cleanup_after(self, request_id: str) -> None:
+        async with self._lock:
+            ctx = self._inflight.get(request_id)
+            if ctx and ctx.task.done():
+                self._inflight.pop(request_id, None)
+
+    async def _cancel_one(self, request_id: str) -> bool:
+        async with self._lock:
+            ctx = self._inflight.get(request_id)
+            if not ctx:
+                return False
+            if not ctx.task.done():
+                ctx.task.cancel()
+                return True
+            return False
+
+    async def _cancel_all(self) -> None:
+        async with self._lock:
+            for ctx in list(self._inflight.values()):
+                if not ctx.task.done():
+                    ctx.task.cancel()
 
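+    # Subclasses implement on_call_chat_completion as an async generator; each
+    # yielded chunk must expose model_dump_json() (e.g. an LLMResponse) and is
+    # streamed back to the caller as a non-final CmdResult by _run_stream.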
     @abstractmethod
     def on_call_chat_completion(