Commit 32a2530

feat: Return logprobs (IBM#8)
1 parent 718096f commit 32a2530

18 files changed: 247 additions & 94 deletions

.github/workflows/server-tests.yaml

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+name: Server Tests
+
+on:
+  pull_request:
+    paths:
+      - "server/**"
+      - "proto/**"
+
+jobs:
+  run_tests:
+    runs-on: ubuntu-20.04
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.9
+      - name: Loading cache.
+        uses: actions/cache@v2
+        id: model_cache
+        with:
+          path: ~/.cache/huggingface/
+          key: models
+      - name: Install server dependencies
+        run: |
+          make install-server
+      - name: Run tests
+        run: |
+          pip install pytest
+          pytest -sv server/tests

README.md

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - 45ms per token generation for BLOOM with 8xA100 80GB
 - Logits warpers (temperature scaling, topk ...)
 - Stop sequences
+- Log probabilities
 
 ## Officially supported models
 

proto/generate.proto

Lines changed: 12 additions & 6 deletions
@@ -27,7 +27,7 @@ message ClearCacheRequest {}
 /// Empty response
 message ClearCacheResponse {}
 
-message LogitsWarperParameters {
+message NextTokenChooserParameters {
     /// exponential scaling output probability distribution
     float temperature = 1;
     /// restricting to the k highest probability elements
@@ -52,8 +52,8 @@ message Request {
     string inputs = 2;
     /// The number of tokens inside inputs
     uint32 input_length = 3;
-    /// Logits Warper Parameters
-    LogitsWarperParameters parameters = 4;
+    /// Next Token Chooser Parameters
+    NextTokenChooserParameters parameters = 4;
     /// Stopping Criteria Parameters
     StoppingCriteriaParameters stopping_parameters = 5;
 }
@@ -71,11 +71,17 @@ message GeneratedText {
     /// Request
     Request request = 1;
     /// Output
-    string output = 2;
+    string output_text = 2;
     /// Number of generated tokens
-    uint32 tokens = 3;
+    uint32 generated_tokens = 3;
+    /// Tokens
+    repeated string tokens = 4;
+    /// Token IDs
+    repeated uint32 token_ids = 5;
+    /// Logprobs
+    repeated float logprobs = 6;
     /// Finish reason
-    string finish_reason = 4;
+    string finish_reason = 7;
 }
 
 message GenerateRequest {
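
The GeneratedText message now carries three index-aligned repeated fields (tokens, token_ids, logprobs), one entry per generated token, alongside the renamed output_text and generated_tokens fields. A minimal Python sketch of reading them from the generated protobuf class; the import path and the example values are illustrative assumptions, not part of this diff:

from text_generation.pb import generate_pb2  # import path is an assumption

# Hypothetical message with made-up values, shaped like the RPC output
generated = generate_pb2.GeneratedText(
    output_text="TestTest",
    generated_tokens=2,
    tokens=["Test", "Test"],
    token_ids=[9288, 9288],
    logprobs=[-1.23, -0.45],
    finish_reason="length",
)

# The repeated fields are parallel arrays: entry i describes generated token i
for token_id, token, logprob in zip(generated.token_ids, generated.tokens, generated.logprobs):
    print(token_id, token, logprob)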

router/client/src/lib.rs

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ mod sharded_client;
 
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;

router/src/batcher.rs

Lines changed: 10 additions & 3 deletions
@@ -187,9 +187,13 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
+
         let response = InferResponse {
-            output: output.output,
+            output_text: output.output_text,
+            generated_tokens: output.generated_tokens,
+            token_ids: output.token_ids,
             tokens: output.tokens,
+            logprobs: output.logprobs,
             finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
@@ -202,8 +206,11 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) output: String,
-    pub(crate) tokens: u32,
+    pub(crate) output_text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) token_ids: Vec<u32>,
+    pub(crate) tokens: Vec<String>,
+    pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,

router/src/db.rs

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;
@@ -71,7 +71,7 @@ impl State {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
                 input_length: entry.input_length as u32,
-                parameters: Some(LogitsWarperParameters::from(
+                parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
                 stopping_parameters: Some(StoppingCriteriaParameters::from(
@@ -162,7 +162,7 @@ impl Db {
     }
 }
 
-impl From<GenerateParameters> for LogitsWarperParameters {
+impl From<GenerateParameters> for NextTokenChooserParameters {
     fn from(parameters: GenerateParameters) -> Self {
         Self {
             temperature: parameters.temperature,

router/src/lib.rs

Lines changed: 13 additions & 1 deletion
@@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
+    #[serde(default)]
     pub stop: Vec<String>,
+    #[serde(default)]
+    pub details: bool,
 }
 
 fn default_temperature() -> f32 {
@@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters {
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
+        details: false,
     }
 }
 
@@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }
 
+#[derive(Serialize)]
+pub(crate) struct Details {
+    pub finish_reason: String,
+    pub generated_tokens: u32,
+    pub tokens: Vec<(u32, String, f32)>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
-    pub finish_reason: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub details: Option<Details>,
 }
 
 #[derive(Serialize)]
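
Because GeneratedText.details is an Option<Details> marked skip_serializing_if = "Option::is_none", the JSON payload only grows when the caller asks for details. A sketch of the two serialized shapes, written as Python literals purely for illustration (values are made up; field names come from this diff):

# Default response: no "details" key is emitted
without_details = [{"generated_text": "TestTest"}]

# With details requested: finish_reason moves into the details object and
# "tokens" serializes each (token_id, token_text, logprob) tuple as a JSON array
with_details = [
    {
        "generated_text": "TestTest",
        "details": {
            "finish_reason": "length",
            "generated_tokens": 2,
            "tokens": [[9288, "Test", -1.23], [9288, "Test", -0.45]],
        },
    }
]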

router/src/server.rs

Lines changed: 26 additions & 5 deletions
@@ -1,5 +1,5 @@
 use crate::{
-    Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
+    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
@@ -54,6 +54,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 do_sample: false,
                 max_new_tokens: 1,
                 stop: vec![],
+                details: false,
             },
         },
     )
@@ -89,6 +90,7 @@ async fn generate(
     })?;
 
     // Validate request
+    let details = req.0.parameters.details;
     let (input_length, validated_request) =
         state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());
@@ -105,12 +107,31 @@
             err
         })?;
 
+    // Token details
+    let details = match details {
+        true => {
+            let tokens = response
+                .token_ids
+                .into_iter()
+                .zip(response.tokens.into_iter())
+                .zip(response.logprobs.into_iter())
+                .map(|((id, text), logprob)| (id, text, logprob))
+                .collect();
+            Some(Details {
+                finish_reason: response.finish_reason,
+                generated_tokens: response.generated_tokens,
+                tokens,
+            })
+        }
+        false => None,
+    };
+
     // Timings
     let total_time = start_time.elapsed();
     let validation_time = response.queued - start_time;
     let queue_time = response.start - response.queued;
     let inference_time = response.end - response.start;
-    let time_per_token = inference_time / response.tokens;
+    let time_per_token = inference_time / response.generated_tokens;
 
     // Headers
     let mut headers = HeaderMap::new();
@@ -141,12 +162,12 @@
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::info!("Output: {}", response.output);
+    tracing::info!("Output: {}", response.output_text);
 
     // Send response
     let response = vec![GeneratedText {
-        generated_text: response.output,
-        finish_reason: response.finish_reason,
+        generated_text: response.output_text,
+        details,
     }];
     Ok((headers, Json(response)))
 }
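
Putting the router changes together, a caller opts into per-token log probabilities by setting details in the request parameters. A hedged client sketch in Python; the host, port, /generate path and the requests dependency are assumptions, while the parameter and response field names come from this diff:

import requests

# Assumed local router address and route; adjust to the actual deployment
resp = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "Test",
        "parameters": {"max_new_tokens": 2, "details": True},
    },
)
resp.raise_for_status()

result = resp.json()[0]  # the handler returns a one-element list
print(result["generated_text"])
if "details" in result:
    for token_id, token, logprob in result["details"]["tokens"]:
        print(token_id, token, logprob)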

server/tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 
 @pytest.fixture
 def default_pb_parameters():
-    return generate_pb2.LogitsWarperParameters(
+    return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
         top_k=0,
         top_p=1.0,

server/tests/models/test_bloom.py

Lines changed: 20 additions & 12 deletions
@@ -128,10 +128,12 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -170,10 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -240,10 +244,10 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )
 
@@ -259,10 +263,12 @@ def test_batch_concatenate(
     assert next_batch is not None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )
 
@@ -279,9 +285,11 @@ def test_batch_concatenate(
     assert next_batch is None
 
     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
