@@ -159,7 +159,6 @@ async def reconfigure(
     @property
     def max_batch_size(self):
         return (self.args.model_config.generation.max_batch_size if self.args.model_config.generation else 1)
-        # return 1
 
     @property
     def batch_wait_timeout_s(self):
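Review note: the hunk above makes `max_batch_size` always derive from the generation config, falling back to 1 (no batching) when that config is absent, instead of toggling a hard-coded `return 1`. A minimal sketch of the fallback pattern, with `ModelConfig`/`GenerationConfig` as illustrative stand-ins for the repo's actual types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationConfig:
    max_batch_size: int = 8

@dataclass
class ModelConfig:
    # Illustrative stand-in; the real config type lives elsewhere in the repo.
    generation: Optional[GenerationConfig] = None

class Deployment:
    def __init__(self, model_config: ModelConfig):
        self.model_config = model_config

    @property
    def max_batch_size(self) -> int:
        # Fall back to 1 (effectively no batching) when no generation config is set.
        gen = self.model_config.generation
        return gen.max_batch_size if gen else 1

assert Deployment(ModelConfig()).max_batch_size == 1
assert Deployment(ModelConfig(GenerationConfig(16))).max_batch_size == 16
```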
@@ -194,30 +193,12 @@ async def generate_text(self, prompt: Prompt):
         with async_timeout.timeout(GATEWAY_TIMEOUT_S):
             text = await self.generate_text_batch(
                 prompt,
-                # [prompt],
-                # priority=QueuePriority.GENERATE_TEXT,
                 # start_timestamp=start_timestamp,
             )
         logger.info(f"generated text: {text}")
         # return text[0]
         return text
 
-    # no need anymore, will be delete soon
-    async def generate(self, prompt: Prompt):
-        time.time()
-        logger.info(prompt)
-        logger.info(self.get_max_batch_size())
-        logger.info(self.get_batch_wait_timeout_s())
-        with async_timeout.timeout(GATEWAY_TIMEOUT_S):
-            text = await self.generate_text_batch(
-                prompt,
-                # [prompt],
-                # priority=QueuePriority.GENERATE_TEXT,
-                # start_timestamp=start_timestamp,
-            )
-        return text
-        # return text[0]
-
     @app.post("/batch", include_in_schema=False)
     async def batch_generate_text(self, prompts: List[Prompt]):
         logger.info(f"batch_generate_text prompts: {prompts}")
@@ -229,7 +210,6 @@ async def batch_generate_text(self, prompts: List[Prompt]):
             *[
                 self.generate_text_batch(
                     prompt,
-                    # priority=QueuePriority.BATCH_GENERATE_TEXT,
                     # start_timestamp=start_timestamp,
                 )
                 for prompt in prompts
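Review note: after these hunks, two request paths remain. `generate_text` bounds a single prompt with the gateway timeout, and `batch_generate_text` fans prompts out concurrently through `asyncio.gather`. A self-contained sketch of that shape, assuming async-timeout 4.x (`async with`; the diff itself uses the older plain `with` form) and a stand-in `generate_text_batch`:

```python
import asyncio
import async_timeout  # pip install async-timeout

GATEWAY_TIMEOUT_S = 90  # assumed value; the real constant is defined elsewhere in this file

async def generate_text_batch(prompt: str) -> str:
    # Placeholder for the real batched-inference call.
    await asyncio.sleep(0)
    return f"generated: {prompt}"

async def generate_text(prompt: str) -> str:
    # Single-prompt path: the whole call is bounded by the gateway timeout.
    async with async_timeout.timeout(GATEWAY_TIMEOUT_S):
        return await generate_text_batch(prompt)

async def batch_generate_text(prompts: list[str]) -> list[str]:
    # Batch path: one generate_text_batch call per prompt, awaited
    # concurrently; results come back in input order.
    return await asyncio.gather(*[generate_text_batch(p) for p in prompts])

print(asyncio.run(batch_generate_text(["a", "b"])))  # ['generated: a', 'generated: b']
```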
@@ -333,20 +313,22 @@ def __init__(self, models: Dict[str, DeploymentHandle], model_configurations: Di
     async def predict(self, model: str, prompt: Union[Prompt, List[Prompt]]) -> Union[Dict[str, Any], List[Dict[str, Any]], List[Any]]:
         logger.info(f"url: {model}, keys: {self._models.keys()}")
         modelKeys = list(self._models.keys())
-        # model = _replace_prefix(model)
+
         modelID = model
         for item in modelKeys:
             logger.info(f"_reverse_prefix(item): {_reverse_prefix(item)}")
             if _reverse_prefix(item) == model:
                 modelID = item
                 logger.info(f"set modelID: {item}")
         logger.info(f"search model key {modelID}")
+
         if isinstance(prompt, Prompt):
             results = await asyncio.gather(*[self._models[modelID].generate_text.remote(prompt)])
         elif isinstance(prompt, list):
             results = await asyncio.gather(*[self._models[modelID].batch_generate_text.remote(prompt)])
         else:
             raise Exception("Invalid prompt format.")
+
         logger.info(f"{results}")
         return results[0]
 
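Review note on the last hunk: `predict` resolves the URL-style model name back to a registered key via `_reverse_prefix`, then dispatches to the single or batch handle method based on the prompt type. A standalone sketch of that dispatch (no Ray: `FakeHandle` replaces the `DeploymentHandle`, and this `_reverse_prefix` is an assumed stand-in for the repo's helper):

```python
import asyncio
from typing import Dict, List, Union

def _reverse_prefix(key: str) -> str:
    # Assumed behaviour of the repo's helper: map a registered key such as
    # "org--model" back to its URL form "org/model". The real function may differ.
    return key.replace("--", "/")

class FakeHandle:
    # Stand-in for a Ray Serve DeploymentHandle; real handles are invoked
    # via handle.method.remote(...), which returns an awaitable.
    async def generate_text(self, prompt: str) -> str:
        return f"one: {prompt}"

    async def batch_generate_text(self, prompts: List[str]) -> List[str]:
        return [f"many: {p}" for p in prompts]

class Router:
    def __init__(self, models: Dict[str, FakeHandle]):
        self._models = models

    async def predict(self, model: str, prompt: Union[str, List[str]]):
        # Resolve the URL-style name back to a registered key, falling back
        # to the raw name when nothing matches (as in the diff above).
        model_id = next((k for k in self._models if _reverse_prefix(k) == model), model)
        handle = self._models[model_id]
        if isinstance(prompt, str):
            return await handle.generate_text(prompt)
        if isinstance(prompt, list):
            return await handle.batch_generate_text(prompt)
        raise ValueError("Invalid prompt format.")

router = Router({"org--model": FakeHandle()})
print(asyncio.run(router.predict("org/model", ["a", "b"])))  # ['many: a', 'many: b']
```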