
Commit d31197b

alex-jw-brooks authored and tjohnson31415 committed
Vectorized next token chooser for causal_lm
This pull request (mostly) ports the heterogeneous next token chooser, which is used for flash models in TGI, into Causal LM.

Co-authored-by: Alex Brooks <[email protected]>
Co-authored-by: Travis Johnson <[email protected]>
1 parent 6c670dd commit d31197b
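For context on what "vectorized" means here: instead of applying one logits processor per request in a Python loop, the heterogeneous processors apply per-request parameters to the whole batch of last-position logits in a single tensor operation. A rough, self-contained sketch of the idea for repetition penalty (illustrative only; this is not the TGIS implementation, and the tensor values are made up):

import torch

# Toy sketch: apply a different repetition penalty to every request in the
# batch with one gather/where/scatter, instead of looping over requests.
scores = torch.randn(2, 25)                        # (batch, vocab) logits for the last position
input_ids = torch.tensor([[1, 2, 1], [3, 4, 3]])   # previously seen token ids per request
penalty = torch.tensor([1.0, 2.5]).unsqueeze(1)    # one penalty per request; 1.0 == disabled

seen = torch.gather(scores, 1, input_ids)                     # scores of already-seen tokens
seen = torch.where(seen < 0, seen * penalty, seen / penalty)  # penalize them
scores.scatter_(1, input_ids, seen)                           # write back, no per-request loop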

File tree

11 files changed: +538 −50 lines changed

integration_tests/test_cases_bloom560m.yaml

Lines changed: 42 additions & 4 deletions
@@ -1209,21 +1209,59 @@
 
 
   # Repetition penalty
-  - name: Repetition penalty
+  - name: Repetition penalty - disabled and has repetition
+    request:
+      params:
+        stopping:
+          minNewTokens: 21
+          maxNewTokens: 21
+      requests:
+        - {"text": "I will be a good AI."}
+    response:
+      responses:
+        - generatedTokenCount: 21
+          inputTokenCount: 7
+          stopReason: MAX_TOKENS
+          text: ' I will be a good AI. I will be a good AI. I will be a good AI.'
+
+  - name: Repetition penalty - enabled to remove repetition
     request:
       params:
         decoding:
           repetition_penalty: 2.5
         stopping:
+          minNewTokens: 20
           maxNewTokens: 20
       requests:
-        - {"text": "A very long story:\n"}
+        - {"text": "I will be a good AI."}
     response:
       responses:
         - generatedTokenCount: 20
-          inputTokenCount: 6
+          inputTokenCount: 7
           stopReason: MAX_TOKENS
-          text: The first time I saw the movie, it was in a theater. It had been on my mind
+          text: ' I have been working on this for the past few years and am now ready to start
+            my own company'
+
+  - name: Repetition penalty with truncation
+    request:
+      params:
+        truncateInputTokens: 7
+        decoding:
+          repetition_penalty: 2.5
+        stopping:
+          minNewTokens: 20
+          maxNewTokens: 20
+      requests:
+        # Truncation removes "I have been working." which means it can be generated again
+        - {"text": "I have been working.I will be a good AI."}
+    response:
+      responses:
+        - generatedTokenCount: 20
+          inputTokenCount: 7
+          stopReason: MAX_TOKENS
+          text: ' I have been working on this for the past few years and am now ready to start
+            my own company'
+
 
   # Length penalty
   - name: Length penalty
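The new bloom560m cases pin down the expected behaviour of repetition_penalty: with the penalty disabled, the prompt "I will be a good AI." is echoed over and over, while a penalty of 2.5 pushes generation away from already-seen tokens. As a rough illustration of what such a penalty does to the logits, here is a minimal sketch using the Transformers processor that the new unit tests below compare against (the example token ids and logits are made up):

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

# Toy logits over a 10-token vocabulary; tokens 5 and 9 were already generated.
input_ids = torch.tensor([[5, 9, 5]])
scores = torch.rand(1, 10)

proc = RepetitionPenaltyLogitsProcessor(penalty=2.5)
penalized = proc(input_ids, scores.clone())
# Positive logits of previously seen tokens (5 and 9) are divided by 2.5
# (negative ones would be multiplied), making their re-selection less likely.
print(scores[0, [5, 9]], penalized[0, [5, 9]])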

server/tests/test_logit_processors.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import pytest
import torch

from text_generation_server.utils.logits_process import (
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousTemperatureLogitsWarper,
    HeterogeneousTopKLogitsWarper,
    HeterogeneousTopPLogitsWarper,
    HeterogeneousTypicalLogitsWarper,
    StaticWarper
)
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

##############################################################################
# Tests for comparing vectorized heterogeneous logit processors to their
# sequential implementations. In these tests, we only check valid cases, because
# the vectorized overrides generally don't provide any input validation.

BATCH_SIZE = 2
VOCAB_DIM = 25
# Input IDs of shape (batch_size x sequence_length);
# chosen intentionally to have repetition etc.
INPUT_IDS = torch.tensor([
    [1, 2, 1, 3, 4, 6, 7, 1, 1, 1],
    [1, 7, 0, 3, 4, 6, 7, 1, 1, 1],
], dtype=torch.long)
# NOTE: We assume BATCH_SIZE x VOCAB_DIM instead of BATCH_SIZE x SEQ_LEN x VOCAB_DIM
# because the vectorized operations are designed to work on the last set of logits in
# the sequence. I.e., this is effectively x[:, -1, :] of the 3rd order tensor.
FULL_SCORES = torch.softmax(torch.rand((BATCH_SIZE, VOCAB_DIM), dtype=torch.float32), dim=-1)

def compare_individual_vs_vectorized_scores(s_warped, v_warped):
    """Given scores warped individually, compare to scores warped with a vectorized
    implementation.

    Args:
        s_warped: List[torch.Tensor]
            List of tensors warped as single entries.
        v_warped: torch.Tensor
            Warped tensor mat.
    """
    assert len(s_warped) == v_warped.shape[0]
    for idx, s_warped_scores in enumerate(s_warped):
        v_warped_scores = v_warped[idx]
        assert torch.allclose(s_warped_scores.squeeze(), v_warped_scores)

def test_alignment_repetition_penalty_logits_processor():
    """Ensure that the repetition penalty is consistent when it is/isn't vectorized."""
    # NOTE: 1.0 tests the case with no penalty
    penalties = [1.0, 2.5]
    # Apply the vectorized repetition logits processor over everything,
    # given that we have a heterogeneous set of penalties to apply
    vectorized_proc = HeterogeneousRepetitionPenaltyLogitsProcessor(
        penalty=penalties,
        dtype=torch.float32,
        device=None,
    )
    v_warped = vectorized_proc(input_ids=INPUT_IDS, scores=FULL_SCORES)
    # Apply each penalty one at a time using the nonvectorized processor
    s_warped = []
    for penalty, logits, ids in zip(penalties, FULL_SCORES, INPUT_IDS):
        single_proc = RepetitionPenaltyLogitsProcessor(penalty=penalty)
        s_warped.append(single_proc(ids.unsqueeze(dim=0), logits.unsqueeze(dim=0)))
    compare_individual_vs_vectorized_scores(s_warped, v_warped)


def test_alignment_temperature_logits_processor():
    """Ensure that the temperature warping is consistent when it is/isn't vectorized."""
    # NOTE: 1.0 tests the case with no temperature warping
    temperatures = [0.25, 1]
    vectorized_proc = HeterogeneousTemperatureLogitsWarper(
        temperature=temperatures,
        dtype=torch.float32,
        device=None,
    )
    # Vectorized temperature warping happens in place; clone the score tensor!
    score_clone = FULL_SCORES.clone()
    v_warped = vectorized_proc(input_ids=INPUT_IDS, scores=score_clone.view(2, -1))

    s_warped = []
    for temp, logits in zip(temperatures, FULL_SCORES):
        # We are testing alignment with TemperatureLogitsWarper
        # through the StaticWarper wrapper class, both for the no-op case
        # and for the case where we actually modify our scores.
        single_proc = StaticWarper(temperature=temp)
        # NOTE: static warpers return a tuple with scores + logprobs (if enabled);
        # we only care about comparing the first one, i.e., scores, here.
        s_warped.append(single_proc(logits.unsqueeze(dim=0))[0])
    compare_individual_vs_vectorized_scores(s_warped, v_warped)


@pytest.mark.parametrize("top_k", [[0, 3], [1, 3]])
def test_alignment_top_k_logits_processor(top_k):
    """Ensure that the top k warping is consistent when it is/isn't vectorized."""
    vectorized_proc = HeterogeneousTopKLogitsWarper(
        top_k=top_k,
        device=None,
    )
    # Top k filtering happens in place; clone the score tensor!
    score_clone = FULL_SCORES.clone()
    v_warped = vectorized_proc(input_ids=INPUT_IDS, scores=score_clone)

    s_warped = []
    for k, logits in zip(top_k, FULL_SCORES):
        # We are testing alignment with TopKLogitsWarper
        # through the StaticWarper wrapper class, both when we have
        # things in the batch to ignore, and when we care about everything.
        single_proc = StaticWarper(top_k=k)
        # NOTE: static warpers return a tuple with scores + logprobs (if enabled);
        # we only care about comparing the first one, i.e., scores, here.
        s_warped.append(single_proc(logits.unsqueeze(dim=0))[0])
    compare_individual_vs_vectorized_scores(s_warped, v_warped)


def test_alignment_top_p_logits_processor():
    """Ensure that the top p warping is consistent when it is/isn't vectorized."""
    top_p = [.9, 0]
    vectorized_proc = HeterogeneousTopPLogitsWarper(
        top_p=top_p,
        dtype=torch.float32,
        device=None,
    )
    # Top p filtering happens in place; clone the score tensor!
    score_clone = FULL_SCORES.clone()
    v_warped = vectorized_proc(input_ids=INPUT_IDS, scores=score_clone)

    s_warped = []
    for p, logits in zip(top_p, FULL_SCORES):
        # We are testing alignment with TopPLogitsWarper through the StaticWarper
        # wrapper class. Be aware that TopPLogitsWarper is an implementation
        # in TGIS, not in Transformers!
        single_proc = StaticWarper(top_p=p)
        # NOTE: static warpers return a tuple with scores + logprobs (if enabled);
        # we only care about comparing the first one, i.e., scores, here.
        s_warped.append(single_proc(logits.unsqueeze(dim=0))[0])
    compare_individual_vs_vectorized_scores(s_warped, v_warped)


def test_alignment_typical_logits_processor():
    """Ensure that the typical logit warping is consistent when it is/isn't vectorized."""
    masses = [.7, .9]
    vectorized_proc = HeterogeneousTypicalLogitsWarper(
        mass=masses,
        dtype=torch.float32,
        device=None,
    )
    # Typical logits filtering happens in place; clone the score tensor!
    score_clone = FULL_SCORES.clone()
    v_warped = vectorized_proc(input_ids=INPUT_IDS, scores=score_clone)

    s_warped = []
    for mass, logits in zip(masses, FULL_SCORES):
        # We are testing alignment with TypicalLogitsWarper through the StaticWarper
        # wrapper class. Be aware that TypicalLogitsWarper is an implementation
        # in TGIS, not in Transformers!
        single_proc = StaticWarper(typical_p=mass)
        # NOTE: static warpers return a tuple with scores + logprobs (if enabled);
        # we only care about comparing the first one, i.e., scores, here.
        s_warped.append(single_proc(logits.unsqueeze(dim=0))[0])
    compare_individual_vs_vectorized_scores(s_warped, v_warped)
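These tests exercise each vectorized processor in isolation; per the commit description, the causal LM path composes them into a single next token chooser. As a minimal sketch of how such warpers can be chained over one batch of last-position logits, reusing the constructor signatures shown in the tests above (illustrative only; this is not the chooser's actual implementation, and the parameter values are made up):

import torch
from text_generation_server.utils.logits_process import (
    HeterogeneousRepetitionPenaltyLogitsProcessor,
    HeterogeneousTemperatureLogitsWarper,
)

input_ids = torch.tensor([[1, 2, 1], [3, 4, 3]], dtype=torch.long)
scores = torch.rand(2, 25)  # last-position logits for a batch of two requests

# One per-request parameter list per warper, as in the tests above.
warpers = [
    HeterogeneousRepetitionPenaltyLogitsProcessor(penalty=[1.0, 2.5], dtype=torch.float32, device=None),
    HeterogeneousTemperatureLogitsWarper(temperature=[0.25, 1.0], dtype=torch.float32, device=None),
]
for warper in warpers:
    # Several of these warpers modify scores in place; chaining mirrors the tests above.
    scores = warper(input_ids=input_ids, scores=scores)

Assuming the usual layout of the TGIS server tests, the new file can likely be run with pytest server/tests/test_logit_processors.py.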
