Commit 611e21c

fix(server): Fix stop sequences (IBM#11)
1 parent 3e2e624 commit 611e21c

5 files changed: +77 -76

launcher/tests/integration_tests.rs

Lines changed: 22 additions & 11 deletions
@@ -1,13 +1,13 @@
-use std::fs::File;
+use float_eq::assert_float_eq;
+use serde::Deserialize;
 use serde_json::Value;
+use std::fs::File;
 use std::io::{BufRead, BufReader};
 use std::path::PathBuf;
 use std::thread;
 use std::thread::sleep;
 use std::time::Duration;
-use float_eq::assert_float_eq;
 use subprocess::{Popen, PopenConfig, Redirection};
-use serde::Deserialize;
 
 #[derive(Deserialize)]
 struct Details {
@@ -22,7 +22,6 @@ struct GeneratedText {
     details: Details,
 }
 
-
 fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
     let argv = vec![
         "text-generation-launcher".to_string(),
@@ -46,7 +45,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
             ..Default::default()
         },
     )
-        .expect("Could not start launcher");
+    .expect("Could not start launcher");
 
     // Redirect STDOUT and STDERR to the console
     let launcher_stdout = launcher.stdout.take().unwrap();
@@ -63,7 +62,7 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
         }
     });
 
-    for _ in 0..30 {
+    for _ in 0..60 {
         let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
         if health.is_ok() {
             return launcher;
@@ -76,7 +75,12 @@ fn start_launcher(model_name: String, num_shard: usize, port: usize, master_port
     panic!("failed to launch {}", model_name)
 }
 
-fn test_model(model_name: String, num_shard: usize, port: usize, master_port: usize) -> GeneratedText {
+fn test_model(
+    model_name: String,
+    num_shard: usize,
+    port: usize,
+    master_port: usize,
+) -> GeneratedText {
     let mut launcher = start_launcher(model_name, num_shard, port, master_port);
 
     let data = r#"
@@ -101,7 +105,6 @@ fn test_model(model_name: String, num_shard: usize, port: usize, master_port: us
     results.pop().unwrap()
 }
 
-
 fn read_json(name: &str) -> GeneratedText {
     let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
     d.push("tests/");
@@ -117,9 +120,17 @@ fn read_json(name: &str) -> GeneratedText {
 fn compare_results(result: GeneratedText, expected: GeneratedText) {
     assert_eq!(result.generated_text, expected.generated_text);
     assert_eq!(result.details.finish_reason, expected.details.finish_reason);
-    assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);
-
-    for (token, expected_token) in result.details.tokens.into_iter().zip(expected.details.tokens.into_iter()) {
+    assert_eq!(
+        result.details.generated_tokens,
+        expected.details.generated_tokens
+    );
+
+    for (token, expected_token) in result
+        .details
+        .tokens
+        .into_iter()
+        .zip(expected.details.tokens.into_iter())
+    {
         assert_eq!(token.0, expected_token.0);
         assert_eq!(token.1, expected_token.1);
         if let Some(logprob) = token.2 {

server/tests/test_utils.py

Lines changed: 19 additions & 32 deletions
@@ -11,46 +11,33 @@
 
 
 def test_stop_sequence_criteria():
-    criteria = StopSequenceCriteria([1, 2, 3])
+    criteria = StopSequenceCriteria("/test;")
 
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert criteria(3)
-    assert criteria.current_token_idx == 3
+    assert not criteria("/")
+    assert not criteria("/test")
+    assert criteria("/test;")
+    assert not criteria("/test; ")
 
 
-def test_stop_sequence_criteria_reset():
-    criteria = StopSequenceCriteria([1, 2, 3])
-
-    assert not criteria(1)
-    assert criteria.current_token_idx == 1
-    assert not criteria(2)
-    assert criteria.current_token_idx == 2
-    assert not criteria(4)
-    assert criteria.current_token_idx == 0
-
-
-def test_stop_sequence_criteria_empty():
-    with pytest.raises(ValueError):
-        StopSequenceCriteria([])
+def test_stopping_criteria():
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(65827, "/test") == (False, None)
+    assert criteria(30, ";") == (True, "stop_sequence")
 
 
-def test_stopping_criteria():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 2]) == (False, None)
-    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+def test_stopping_criteria_eos():
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(0, "") == (True, "eos_token")
 
 
 def test_stopping_criteria_max():
-    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
-    assert criteria([1]) == (False, None)
-    assert criteria([1, 1]) == (False, None)
-    assert criteria([1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1]) == (False, None)
-    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+    criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (False, None)
+    assert criteria(1, "") == (True, "length")
 
 
 def test_weight_hub_files():
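
Note: the rewritten tests exercise string matching on accumulated decoded output
rather than token-id matching, which is why a stop sequence split across two
tokens ("/test" then ";") still triggers; the token ids 65827 and 30 are
arbitrary non-eos values. A minimal sketch of the matching rule, using the same
pattern the new StopSequenceCriteria compiles:

import re

# StopSequenceCriteria compiles f".*{stop_sequence}$": the accumulated
# output must *end* with the stop sequence for the match to fire.
regex = re.compile(".*/test;$")

print(bool(regex.findall("/test")))    # False: the sequence is incomplete
print(bool(regex.findall("/test;")))   # True: output ends with "/test;"
print(bool(regex.findall("/test; ")))  # False: trailing text breaks the $ anchor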

server/text_generation/models/causal_lm.py

Lines changed: 6 additions & 1 deletion
@@ -345,7 +345,12 @@ def generate_token(
             all_logprobs = torch.cat([all_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Decode all tokens
                 output_text = self.tokenizer.decode(
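
Note: generate_token now decodes just the newly sampled token and hands its text
to the stopping criteria (seq2seq_lm.py below gets the same change). A small
illustration of the decode call, assuming the Hugging Face transformers library;
"gpt2" is an arbitrary tokenizer picked for the example, not one this diff pins:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice only

token_id = tokenizer(" world").input_ids[0]
# clean_up_tokenization_spaces=False preserves the token's exact surface form
# (leading space included), so the per-token strings concatenate into the real
# output text that the stop-sequence regex is checked against.
print(repr(tokenizer.decode(token_id, clean_up_tokenization_spaces=False)))  # ' world'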

server/text_generation/models/seq2seq_lm.py

Lines changed: 6 additions & 1 deletion
@@ -441,7 +441,12 @@ def generate_token(
             decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])
 
             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_input_ids)
+            stop, reason = stopping_criteria(
+                next_token.squeeze(),
+                self.tokenizer.decode(
+                    next_token.squeeze(), clean_up_tokenization_spaces=False
+                ),
+            )
             if stop:
                 # Slice with decoder_input_length to remove padding
                 # Decode all tokens

server/text_generation/utils.py

Lines changed: 24 additions & 31 deletions
@@ -1,5 +1,6 @@
 import concurrent
 import os
+import re
 import torch
 import torch.distributed
 
@@ -74,43 +75,39 @@ def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChoos
 
 
 class StopSequenceCriteria:
-    def __init__(self, tokens: List[int]):
-        if not tokens:
-            raise ValueError("tokens cannot be empty")
-
-        self.tokens = tokens
-        self.current_token_idx = 0
-
-    def __call__(self, last_token: int) -> bool:
-        if last_token == self.tokens[self.current_token_idx]:
-            # Increase idx to go to next token
-            self.current_token_idx += 1
-        else:
-            # Reset to first token of the stopping sequence
-            self.current_token_idx = 0
-
-        if self.current_token_idx == len(self.tokens):
-            # We matched the entire sequence without resetting
+    def __init__(self, stop_sequence: str):
+        self.regex = re.compile(f".*{stop_sequence}$")
+
+    def __call__(self, output: str) -> bool:
+        if self.regex.findall(output):
             return True
         return False
 
 
 class StoppingCriteria:
     def __init__(
-        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+        self,
+        eos_token_id: int,
+        stop_sequence_criterias: List[StopSequenceCriteria],
+        max_new_tokens=20,
     ):
+        self.eos_token_id = eos_token_id
         self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
+        self.current_output = ""
 
-    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
+    def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True, "length"
 
-        last_token = all_ids[-1]
+        if last_token == self.eos_token_id:
+            return True, "eos_token"
+
+        self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
-            if stop_sequence_criteria(last_token):
+            if stop_sequence_criteria(self.current_output):
                 return True, "stop_sequence"
 
         return False, None
@@ -119,16 +116,12 @@ def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
     def from_pb(
         cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
     ) -> "StoppingCriteria":
-        stop_sequence_criterias = []
-        for stop_sequence in pb.stop_sequences:
-            tokens = tokenizer(
-                stop_sequence, padding=False, return_attention_mask=False
-            ).input_ids
-            if tokens:
-                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
-        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
-
-        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)
+        stop_sequence_criterias = [
+            StopSequenceCriteria(sequence) for sequence in pb.stop_sequences
+        ]
+        return StoppingCriteria(
+            tokenizer.eos_token_id, stop_sequence_criterias, pb.max_new_tokens
+        )
 
 
 def initialize_torch_distributed():
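
Note: a minimal end-to-end sketch of the reworked classes above; it assumes the
server package is importable as text_generation and uses eos_token_id=0 plus the
token ids from the tests purely for illustration:

from text_generation.utils import StopSequenceCriteria, StoppingCriteria

criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=20)

# The generation loop feeds each new token id plus its decoded text; the
# criteria object accumulates the text and checks every stop sequence against it.
assert criteria(65827, "/test") == (False, None)     # partial match, keep going
assert criteria(30, ";") == (True, "stop_sequence")  # accumulated "/test;" matches

# Caveat: the stop sequence is interpolated into the regex unescaped, so a
# sequence containing regex metacharacters (e.g. ".") is treated as a pattern;
# wrapping it in re.escape() would make the match literal.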
