Skip to content

Commit b407adc

Browse files
authored
Pass details param into client (#265)
1 parent dd68924 commit b407adc

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

clients/python/lorax/client.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def generate(
8181
watermark: bool = False,
8282
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
8383
decoder_input_details: bool = False,
84+
details: bool = True,
8485
) -> Response:
8586
"""
8687
Given a prompt, generate the following text
@@ -138,6 +139,8 @@ def generate(
138139
```
139140
decoder_input_details (`bool`):
140141
Return the decoder input token logprobs and ids
142+
details (`bool`):
143+
Return the token logprobs and ids for generated tokens
141144

142145
Returns:
143146
Response: generated response
@@ -149,7 +152,7 @@ def generate(
149152
merged_adapters=merged_adapters,
150153
api_token=api_token,
151154
best_of=best_of,
152-
details=True,
155+
details=details,
153156
do_sample=do_sample,
154157
max_new_tokens=max_new_tokens,
155158
repetition_penalty=repetition_penalty,
@@ -202,6 +205,7 @@ def generate_stream(
202205
typical_p: Optional[float] = None,
203206
watermark: bool = False,
204207
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
208+
details: bool = True,
205209
) -> Iterator[StreamResponse]:
206210
"""
207211
Given a prompt, generate the following stream of tokens
@@ -255,6 +259,8 @@ def generate_stream(
255259
}
256260
}
257261
```
262+
details (`bool`):
263+
Return the token logprobs and ids for generated tokens
258264

259265
Returns:
260266
Iterator[StreamResponse]: stream of generated tokens
@@ -266,7 +272,7 @@ def generate_stream(
266272
merged_adapters=merged_adapters,
267273
api_token=api_token,
268274
best_of=None,
269-
details=True,
275+
details=details,
270276
decoder_input_details=False,
271277
do_sample=do_sample,
272278
max_new_tokens=max_new_tokens,
@@ -384,6 +390,7 @@ async def generate(
384390
watermark: bool = False,
385391
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
386392
decoder_input_details: bool = False,
393+
details: bool = True,
387394
) -> Response:
388395
"""
389396
Given a prompt, generate the following text asynchronously
@@ -441,6 +448,8 @@ async def generate(
441448
```
442449
decoder_input_details (`bool`):
443450
Return the decoder input token logprobs and ids
451+
details (`bool`):
452+
Return the token logprobs and ids for generated tokens
444453

445454
Returns:
446455
Response: generated response
@@ -452,7 +461,7 @@ async def generate(
452461
merged_adapters=merged_adapters,
453462
api_token=api_token,
454463
best_of=best_of,
455-
details=True,
464+
details=details,
456465
decoder_input_details=decoder_input_details,
457466
do_sample=do_sample,
458467
max_new_tokens=max_new_tokens,
@@ -500,6 +509,7 @@ async def generate_stream(
500509
typical_p: Optional[float] = None,
501510
watermark: bool = False,
502511
response_format: Optional[Union[Dict[str, Any], ResponseFormat]] = None,
512+
details: bool = True,
503513
) -> AsyncIterator[StreamResponse]:
504514
"""
505515
Given a prompt, generate the following stream of tokens asynchronously
@@ -553,6 +563,8 @@ async def generate_stream(
553563
}
554564
}
555565
```
566+
details (`bool`):
567+
Return the token logprobs and ids for generated tokens
556568

557569
Returns:
558570
AsyncIterator[StreamResponse]: stream of generated tokens
@@ -564,7 +576,7 @@ async def generate_stream(
564576
merged_adapters=merged_adapters,
565577
api_token=api_token,
566578
best_of=None,
567-
details=True,
579+
details=details,
568580
decoder_input_details=False,
569581
do_sample=do_sample,
570582
max_new_tokens=max_new_tokens,

clients/python/lorax/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ class Response(BaseModel):
289289
# Generated text
290290
generated_text: str
291291
# Generation details
292-
details: Details
292+
details: Optional[Details]
293293

294294

295295
# `generate_stream` details

0 commit comments

Comments
 (0)