33#
44# This source code is licensed under the terms described in the LICENSE file in
55# the root directory of this source tree.
6+ import asyncio
7+ from typing import Any
8+
69from llama_stack .apis .inference import (
710 ListOpenAIChatCompletionResponse ,
811 OpenAIChatCompletion ,
912 OpenAICompletionWithInputMessages ,
1013 OpenAIMessageParam ,
1114 Order ,
1215)
13- from llama_stack .core .datatypes import AccessRule
14- from llama_stack .core . utils . config_dirs import RUNTIME_BASE_DIR
16+ from llama_stack .core .datatypes import AccessRule , InferenceStoreConfig
17+ from llama_stack .log import get_logger
1518
1619from ..sqlstore .api import ColumnDefinition , ColumnType
1720from ..sqlstore .authorized_sqlstore import AuthorizedSqlStore
18- from ..sqlstore .sqlstore import SqliteSqlStoreConfig , SqlStoreConfig , sqlstore_impl
21+ from ..sqlstore .sqlstore import SqlStoreConfig , SqlStoreType , sqlstore_impl
22+
23+ logger = get_logger (name = __name__ , category = "inference_store" )
1924
2025
2126class InferenceStore :
22- def __init__ (self , sql_store_config : SqlStoreConfig , policy : list [AccessRule ]):
23- if not sql_store_config :
24- sql_store_config = SqliteSqlStoreConfig (
25- db_path = (RUNTIME_BASE_DIR / "sqlstore.db" ).as_posix (),
27+ def __init__ (
28+ self ,
29+ config : InferenceStoreConfig | SqlStoreConfig ,
30+ policy : list [AccessRule ],
31+ ):
32+ # Handle backward compatibility
33+ if not isinstance (config , InferenceStoreConfig ):
34+ # Legacy: SqlStoreConfig passed directly as config
35+ config = InferenceStoreConfig (
36+ sql_store_config = config ,
2637 )
27- self .sql_store_config = sql_store_config
38+
39+ self .config = config
40+ self .sql_store_config = config .sql_store_config
2841 self .sql_store = None
2942 self .policy = policy
3043
44+ # Disable write queue for SQLite to avoid concurrency issues
45+ self .enable_write_queue = self .sql_store_config .type != SqlStoreType .sqlite
46+
47+ # Async write queue and worker control
48+ self ._queue : asyncio .Queue [tuple [OpenAIChatCompletion , list [OpenAIMessageParam ]]] | None = None
49+ self ._worker_tasks : list [asyncio .Task [Any ]] = []
50+ self ._max_write_queue_size : int = config .max_write_queue_size
51+ self ._num_writers : int = max (1 , config .num_writers )
52+
3153 async def initialize (self ):
3254 """Create the necessary tables if they don't exist."""
3355 self .sql_store = AuthorizedSqlStore (sqlstore_impl (self .sql_store_config ))
@@ -42,10 +64,68 @@ async def initialize(self):
4264 },
4365 )
4466
67+ if self .enable_write_queue :
68+ self ._queue = asyncio .Queue (maxsize = self ._max_write_queue_size )
69+ for _ in range (self ._num_writers ):
70+ self ._worker_tasks .append (asyncio .create_task (self ._worker_loop ()))
71+ else :
72+ logger .info ("Write queue disabled for SQLite to avoid concurrency issues" )
73+
74+ async def shutdown (self ) -> None :
75+ if not self ._worker_tasks :
76+ return
77+ if self ._queue is not None :
78+ await self ._queue .join ()
79+ for t in self ._worker_tasks :
80+ if not t .done ():
81+ t .cancel ()
82+ for t in self ._worker_tasks :
83+ try :
84+ await t
85+ except asyncio .CancelledError :
86+ pass
87+ self ._worker_tasks .clear ()
88+
89+ async def flush (self ) -> None :
90+ """Wait for all queued writes to complete. Useful for testing."""
91+ if self .enable_write_queue and self ._queue is not None :
92+ await self ._queue .join ()
93+
4594 async def store_chat_completion (
4695 self , chat_completion : OpenAIChatCompletion , input_messages : list [OpenAIMessageParam ]
4796 ) -> None :
48- if not self .sql_store :
97+ if self .enable_write_queue :
98+ if self ._queue is None :
99+ raise ValueError ("Inference store is not initialized" )
100+ try :
101+ self ._queue .put_nowait ((chat_completion , input_messages ))
102+ except asyncio .QueueFull :
103+ logger .warning (
104+ f"Write queue full; adding chat completion id={ getattr (chat_completion , 'id' , '<unknown>' )} "
105+ )
106+ await self ._queue .put ((chat_completion , input_messages ))
107+ else :
108+ await self ._write_chat_completion (chat_completion , input_messages )
109+
110+ async def _worker_loop (self ) -> None :
111+ assert self ._queue is not None
112+ while True :
113+ try :
114+ item = await self ._queue .get ()
115+ except asyncio .CancelledError :
116+ break
117+ chat_completion , input_messages = item
118+ try :
119+ await self ._write_chat_completion (chat_completion , input_messages )
120+ except Exception as e : # noqa: BLE001
121+ logger .error (f"Error writing chat completion: { e } " )
122+ finally :
123+ self ._queue .task_done ()
124+
125+ async def _write_chat_completion (
126+ self , chat_completion : OpenAIChatCompletion , input_messages : list [OpenAIMessageParam ]
127+ ) -> None :
128+ if self .sql_store is None :
49129 raise ValueError ("Inference store is not initialized" )
50130
51131 data = chat_completion .model_dump ()
0 commit comments