
Commit 718096f

feat: Support stop sequences (IBM#7)
1 parent 042180d commit 718096f

18 files changed: +254 additions, −107 deletions


README.md

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
+- Logits warpers (temperature scaling, topk ...)
+- Stop sequences

 ## Officially supported models

proto/generate.proto

Lines changed: 15 additions & 2 deletions
@@ -28,12 +28,23 @@ message ClearCacheRequest {}
 message ClearCacheResponse {}

 message LogitsWarperParameters {
+  /// exponential scaling output probability distribution
   float temperature = 1;
+  /// restricting to the k highest probability elements
   uint32 top_k = 2;
+  /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
   float top_p = 3;
+  /// apply sampling on the logits
   bool do_sample = 4;
 }

+message StoppingCriteriaParameters {
+  /// Maximum number of generated tokens
+  uint32 max_new_tokens = 1;
+  /// Optional stopping sequences
+  repeated string stop_sequences = 2;
+}
+
 message Request {
   /// Request ID
   uint64 id = 1;
@@ -43,8 +54,8 @@ message Request {
   uint32 input_length = 3;
   /// Logits Warper Parameters
   LogitsWarperParameters parameters = 4;
-  /// Stopping criteria
-  uint32 max_new_tokens = 5;
+  /// Stopping Criteria Parameters
+  StoppingCriteriaParameters stopping_parameters = 5;
 }

 message Batch {
@@ -63,6 +74,8 @@ message GeneratedText {
   string output = 2;
   /// Number of generated tokens
   uint32 tokens = 3;
+  /// Finish reason
+  string finish_reason = 4;
 }

 message GenerateRequest {
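The proto change moves max_new_tokens into a dedicated StoppingCriteriaParameters message and adds optional stop_sequences, plus a finish_reason on GeneratedText. A minimal sketch of building a shard request against the updated schema, mirroring the test fixtures further down; the import path for the generated module is an assumption:

```python
# Sketch only, not part of the commit. Field names come from the proto diff
# above; the generated-module import path is an assumption.
from text_generation.pb import generate_pb2  # assumed import path

stopping = generate_pb2.StoppingCriteriaParameters(
    max_new_tokens=10,
    stop_sequences=["\n\n"],  # optional stop strings, capped at 4 by the router
)

request = generate_pb2.Request(
    id=0,
    inputs="Test",
    input_length=1,
    parameters=generate_pb2.LogitsWarperParameters(
        temperature=1.0, top_k=0, top_p=1.0, do_sample=False
    ),
    stopping_parameters=stopping,  # replaces the old `uint32 max_new_tokens = 5`
)
```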

router/client/src/lib.rs

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,9 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
-pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
+pub use pb::generate::v1::{
+    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
 use tonic::transport;

router/src/batcher.rs

Lines changed: 2 additions & 0 deletions
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,

router/src/db.rs

Lines changed: 15 additions & 2 deletions
@@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
+use text_generation_client::{
+    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;

@@ -72,7 +74,9 @@ impl State {
                 parameters: Some(LogitsWarperParameters::from(
                     entry.request.parameters.clone(),
                 )),
-                max_new_tokens: entry.request.parameters.max_new_tokens,
+                stopping_parameters: Some(StoppingCriteriaParameters::from(
+                    entry.request.parameters.clone(),
+                )),
             });

             ids.push(*id);
@@ -168,3 +172,12 @@ impl From<GenerateParameters> for LogitsWarperParameters {
         }
     }
 }
+
+impl From<GenerateParameters> for StoppingCriteriaParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            stop_sequences: parameters.stop,
+            max_new_tokens: parameters.max_new_tokens,
+        }
+    }
+}

router/src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
+    pub stop: Vec<String>,
 }

 fn default_temperature() -> f32 {
@@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters {
         top_p: default_top_p(),
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
+        stop: vec![],
     }
 }

@@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]
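On the HTTP side, GenerateParameters gains a stop list and GeneratedText gains a finish_reason. A hypothetical client call could look like the sketch below; the router address and route are assumptions, and the exact finish_reason strings are produced by the Python server, which is not shown in this diff:

```python
# Hypothetical end-to-end call, not part of the commit. Parameter names come
# from the GenerateParameters / GeneratedText structs above; URL and route
# are assumptions.
import requests

response = requests.post(
    "http://localhost:3000/generate",  # assumed router address and route
    json={
        "inputs": "def fibonacci(n):",
        "parameters": {
            "max_new_tokens": 64,
            "stop": ["\n\n"],  # generation ends early if a stop sequence appears
        },
    },
)

# The router replies with a list of GeneratedText objects, each now carrying
# why generation stopped (reason strings are set server-side, not shown here).
for generated in response.json():
    print(generated["finish_reason"], generated["generated_text"])
```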

router/src/server.rs

Lines changed: 4 additions & 5 deletions
@@ -53,6 +53,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 top_p: 1.0,
                 do_sample: false,
                 max_new_tokens: 1,
+                stop: vec![],
             },
         },
     )
@@ -88,11 +89,8 @@ async fn generate(
     })?;

     // Validate request
-    let (input_length, validated_request) = state
-        .validation
-        .validate(req.0)
-        .await
-        .map_err(|err| {
+    let (input_length, validated_request) =
+        state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());
             err
         })?;
@@ -148,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }

router/src/validation.rs

Lines changed: 10 additions & 0 deletions
@@ -121,6 +121,14 @@ fn validation_worker(
                 .unwrap_or(());
             continue;
         }
+        if request.parameters.stop.len() > 4 {
+            response_tx
+                .send(Err(ValidationError::StopSequence(
+                    request.parameters.stop.len(),
+                )))
+                .unwrap_or(());
+            continue;
+        }

         // Get the number of tokens in the input
         match tokenizer.encode(request.inputs.clone(), false) {
@@ -163,6 +171,8 @@ pub enum ValidationError {
     MaxNewTokens,
     #[error("inputs must have less than {1} tokens. Given: {0}")]
     InputLength(usize, usize),
+    #[error("stop supports up to 4 stop sequences. Given: {0}")]
+    StopSequence(usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
 }
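Validation now rejects requests carrying more than four stop sequences before any tokenization happens. A hedged example of a request that would trip the new check; address, route, and status handling are assumptions, while the error text comes from the StopSequence variant above:

```python
# Hypothetical over-limit request, not part of the commit: five stop
# sequences exceed the cap of four enforced in validation_worker.
import requests

bad = requests.post(
    "http://localhost:3000/generate",  # assumed router address and route
    json={
        "inputs": "Hello",
        "parameters": {"max_new_tokens": 10, "stop": ["a", "b", "c", "d", "e"]},
    },
)
# Expected error text, from ValidationError::StopSequence:
#   "stop supports up to 4 stop sequences. Given: 5"
print(bad.status_code, bad.text)
```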

server/tests/conftest.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,11 @@ def default_pb_parameters():
     )


+@pytest.fixture
+def default_pb_stop_parameters():
+    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+
+
 @pytest.fixture(scope="session")
 def bloom_560m_tokenizer():
     return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")

server/tests/models/test_bloom.py

Lines changed: 15 additions & 7 deletions
@@ -9,13 +9,13 @@


 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=1,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )


@@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(
@@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.requests == default_pb_batch.requests

     assert len(batch.input_ids) == default_pb_batch.size
-    assert len(batch.input_ids[0]) == 8
     assert batch.input_ids[0][-1] == 10264
     assert torch.all(batch.input_ids[0][:-1] == 3)

@@ -85,14 +84,19 @@ def test_causal_lm_batch_type(default_bloom):


 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
+    sequence_length = len(default_bloom_batch.all_input_ids[0])
     generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)

     assert generated_texts == []
     assert isinstance(next_batch, CausalLMBatch)
     assert not next_batch.keys_head_dim_last

     assert len(next_batch.all_input_ids) == next_batch.size
-    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
+    assert (
+        len(next_batch.all_input_ids[0])
+        == len(next_batch.attention_mask[0])
+        == sequence_length + 1
+    )
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
     assert torch.all(next_batch.all_input_ids[0][:-2] == 3)

@@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.max_sequence_length == next_batch.input_lengths[0]

     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
+    )


 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
