1+ import asyncio
2+
3+ import aiohttp
14import fastapi
25import pydantic
6+ import yaml
7+ from aiohttp .client_exceptions import ClientConnectorError , ServerTimeoutError
8+ from fastapi import HTTPException
9+ from loguru import logger
310from oasst_inference_server .settings import settings
411from oasst_shared import model_configs
512from oasst_shared .schemas import inference
613
14+ # NOTE: Populate this with plugins that we will provide out of the box
15+ OA_PLUGINS = []
16+
717router = fastapi .APIRouter (
818 prefix = "/configs" ,
919 tags = ["configs" ],
@@ -63,6 +73,16 @@ class ModelConfigInfo(pydantic.BaseModel):
6373 repetition_penalty = 1.2 ,
6474 ),
6575 ),
76+ ParameterConfig (
77+ name = "k50-Plugins" ,
78+ description = "Top-k sampling with k=50 and temperature=0.35" ,
79+ sampling_parameters = inference .SamplingParameters (
80+ max_new_tokens = 1024 ,
81+ temperature = 0.35 ,
82+ top_k = 50 ,
83+ repetition_penalty = (1 / 0.90 ),
84+ ),
85+ ),
6686 ParameterConfig (
6787 name = "nucleus9" ,
6888 description = "Nucleus sampling with p=0.9" ,
@@ -93,6 +113,44 @@ class ModelConfigInfo(pydantic.BaseModel):
93113]
94114
95115
116+ async def fetch_plugin (url : str , retries : int = 3 , timeout : float = 5.0 ) -> inference .PluginConfig :
117+ async with aiohttp .ClientSession () as session :
118+ for attempt in range (retries ):
119+ try :
120+ async with session .get (url , timeout = timeout ) as response :
121+ content_type = response .headers .get ("Content-Type" )
122+
123+ if response .status == 200 :
124+ if "application/json" in content_type or url .endswith (".json" ):
125+ config = await response .json ()
126+ elif (
127+ "application/yaml" in content_type
128+ or "application/x-yaml" in content_type
129+ or url .endswith (".yaml" )
130+ or url .endswith (".yml" )
131+ ):
132+ config = yaml .safe_load (await response .text ())
133+ else :
134+ raise HTTPException (
135+ status_code = 400 ,
136+ detail = f"Unsupported content type: { content_type } . Only JSON and YAML are supported." ,
137+ )
138+
139+ return inference .PluginConfig (** config )
140+ elif response .status == 404 :
141+ raise HTTPException (status_code = 404 , detail = "Plugin not found" )
142+ else :
143+ raise HTTPException (status_code = response .status , detail = "Unexpected status code" )
144+ except (ClientConnectorError , ServerTimeoutError ) as e :
145+ if attempt == retries - 1 : # last attempt
146+ raise HTTPException (status_code = 500 , detail = f"Request failed after { retries } retries: { e } " )
147+ await asyncio .sleep (2 ** attempt ) # exponential backoff
148+
149+ except aiohttp .ClientError as e :
150+ raise HTTPException (status_code = 500 , detail = f"Request failed: { e } " )
151+ raise HTTPException (status_code = 500 , detail = "Failed to fetch plugin" )
152+
153+
96154@router .get ("/model_configs" )
97155async def get_model_configs () -> list [ModelConfigInfo ]:
98156 return [
@@ -103,3 +161,36 @@ async def get_model_configs() -> list[ModelConfigInfo]:
103161 for model_config_name in model_configs .MODEL_CONFIGS
104162 if (settings .allowed_model_config_names == "*" or model_config_name in settings .allowed_model_config_names_list )
105163 ]
164+
165+
166+ @router .post ("/plugin_config" )
167+ async def get_plugin_config (plugin : inference .PluginEntry ) -> inference .PluginEntry :
168+ try :
169+ plugin_config = await fetch_plugin (plugin .url )
170+ except HTTPException as e :
171+ logger .warning (f"Failed to fetch plugin config from { plugin .url } : { e .detail } " )
172+ raise fastapi .HTTPException (status_code = e .status_code , detail = e .detail )
173+
174+ return inference .PluginEntry (url = plugin .url , enabled = plugin .enabled , plugin_config = plugin_config )
175+
176+
177+ @router .get ("/builtin_plugins" )
178+ async def get_builtin_plugins () -> list [inference .PluginEntry ]:
179+ plugins = []
180+
181+ for plugin in OA_PLUGINS :
182+ try :
183+ plugin_config = await fetch_plugin (plugin .url )
184+ except HTTPException as e :
185+ logger .warning (f"Failed to fetch plugin config from { plugin .url } : { e .detail } " )
186+ continue
187+
188+ final_plugin : inference .PluginEntry = inference .PluginEntry (
189+ url = plugin .url ,
190+ enabled = plugin .enabled ,
191+ trusted = plugin .trusted ,
192+ plugin_config = plugin_config ,
193+ )
194+ plugins .append (final_plugin )
195+
196+ return plugins
0 commit comments