 from __future__ import annotations  # for forward references
 
 import hashlib
+import inspect
 import json
 import os
 from collections.abc import Generator
@@ -198,16 +199,11 @@ def _extract_model_identifiers():
 
         Supported endpoints:
         - '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
-        - '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
+        - '/v1/models' (OpenAI): response body is a list: [ { id: ... }, ... ]
         Returns a list of unique identifiers or None if structure doesn't match.
         """
-        body = response["body"]
-        if endpoint == "/api/tags":
-            items = body.get("models")
-            idents = [m.model for m in items]
-        else:
-            items = body.get("data")
-            idents = [m.id for m in items]
+        items = response["body"]
+        idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
         return sorted(set(idents))
 
     identifiers = _extract_model_identifiers()
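The two endpoint branches collapse into a single comprehension keyed on the endpoint. A minimal sketch of that behavior, using hypothetical `SimpleNamespace` stand-ins for the recorded model entries (real records come from the Ollama and OpenAI clients, so these shapes are assumptions):

```python
from types import SimpleNamespace

# hypothetical recorded bodies, shaped as the docstring describes
ollama_body = [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="all-minilm:l6-v2")]
openai_body = [SimpleNamespace(id="gpt-4o"), SimpleNamespace(id="gpt-4o")]

def extract(endpoint, body):
    # one comprehension replaces the per-endpoint if/else blocks
    idents = [m.model if endpoint == "/api/tags" else m.id for m in body]
    return sorted(set(idents))

print(extract("/api/tags", ollama_body))   # ['all-minilm:l6-v2', 'llama3.2:3b']
print(extract("/v1/models", openai_body))  # ['gpt-4o']  (duplicates collapse)
```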
@@ -219,28 +215,22 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
     seen: dict[str, dict[str, Any]] = {}
     for rec in records:
         body = rec["response"]["body"]
-        if endpoint == "/api/tags":
-            items = body.models
-        elif endpoint == "/v1/models":
-            items = body.data
-        else:
-            items = []
-
-        for m in items:
-            if endpoint == "/v1/models":
+        if endpoint == "/v1/models":
+            for m in body:
                 key = m.id
-            else:
+                seen[key] = m
+        elif endpoint == "/api/tags":
+            for m in body.models:
                 key = m.model
-            seen[key] = m
+                seen[key] = m
 
     ordered = [seen[k] for k in sorted(seen.keys())]
     canonical = records[0]
     canonical_req = canonical.get("request", {})
     if isinstance(canonical_req, dict):
         canonical_req["endpoint"] = endpoint
-    if endpoint == "/v1/models":
-        body = {"data": ordered, "object": "list"}
-    else:
+    body = ordered
+    if endpoint == "/api/tags":
         from ollama import ListResponse
 
         body = ListResponse(models=ordered)
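As a sanity check on the merge logic: entries from later records overwrite earlier ones with the same key, and the combined list is returned in sorted key order. A small self-contained sketch of the `/v1/models` path, with record shapes assumed:

```python
from types import SimpleNamespace

# two hypothetical recorded responses for the same endpoint
records = [
    {"response": {"body": [SimpleNamespace(id="b"), SimpleNamespace(id="a")]}},
    {"response": {"body": [SimpleNamespace(id="a"), SimpleNamespace(id="c")]}},
]

seen = {}
for rec in records:
    for m in rec["response"]["body"]:
        seen[m.id] = m  # later records win on duplicate ids

ordered = [seen[k] for k in sorted(seen)]
print([m.id for m in ordered])  # ['a', 'b', 'c']
```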
@@ -252,7 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
 
     if _current_mode == InferenceMode.LIVE or _current_storage is None:
         # Normal operation
-        return await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            return await original_method(self, *args, **kwargs)
+        else:
+            return original_method(self, *args, **kwargs)
 
     # Get base URL based on client type
     if client_type == "openai":
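The `inspect.iscoroutinefunction` guard exists because not every patched method is a coroutine function: `AsyncOpenAI`'s `models.list`, for instance, is a plain method that returns an async-iterable paginator rather than a coroutine. A minimal illustration of the distinction the guard relies on:

```python
import asyncio
import inspect

async def coroutine_style():      # a true coroutine function: must be awaited
    return ["gpt-4o"]

def paginator_style():            # a plain function returning an async generator
    async def _gen():
        yield "gpt-4o"
    return _gen()

print(inspect.iscoroutinefunction(coroutine_style))   # True
print(inspect.iscoroutinefunction(paginator_style))   # False

async def main():
    models = await coroutine_style()            # await the coroutine
    more = [m async for m in paginator_style()] # iterate the generator
    print(models, more)

asyncio.run(main())
```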
@@ -300,7 +293,14 @@ async def replay_stream():
         )
 
     elif _current_mode == InferenceMode.RECORD:
-        response = await original_method(self, *args, **kwargs)
+        if inspect.iscoroutinefunction(original_method):
+            response = await original_method(self, *args, **kwargs)
+        else:
+            response = original_method(self, *args, **kwargs)
+
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
 
         request_data = {
             "method": method,
@@ -380,10 +380,14 @@ async def patched_embeddings_create(self, *args, **kwargs):
             _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
         )
 
-    async def patched_models_list(self, *args, **kwargs):
-        return await _patched_inference_method(
-            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
-        )
+    def patched_models_list(self, *args, **kwargs):
+        async def _iter():
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
+                yield item
+
+        return _iter()
 
     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
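From the caller's side, the reshaped wrapper keeps the usual iteration pattern intact: `models.list()` is called without `await` and the result is consumed with `async for`, mirroring how the real client's paginator is used. A rough usage sketch under that assumption (the client wiring and helper name are hypothetical):

```python
async def list_model_ids(client) -> list[str]:
    # client.models.list() now returns the async generator from _iter(),
    # so it is iterated rather than awaited
    return [m.id async for m in client.models.list()]
```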