Skip to content

Commit 2522f23

Browse files
authored
Remove temperature bottleneck (#1276)
Previously, we were applying temperature to the entire `result_logits` array from a decode invocation. Our `mistral` logits are of shape `[1, 1, 128256]`. In a local benchmark, the `sfnp.divide` call on an array of this size took ~14ms, so `(14ms * 16 parallel_reqs * 64 decode_steps) / 1000 == ~14s` of overhead. With this change, local benchmarks show the latency of 16 concurrent requests dropping from `~25s` to `~9s`, which also increased throughput considerably. The idea is simple: skip the temperature division for greedy selection, and for logits already in `softmax`/`log_softmax` form. Dividing by a scalar does not change which token is the highest scoring, nor which tokens are the `top_k` highest scoring, and applying temperature to values already in `softmax` or `log_softmax` form is not meaningful. Otherwise, apply temperature only to the (much smaller) subset of values that are actually being converted to `softmax`, which drastically reduces the size of the array being divided.
1 parent e5e85f5 commit 2522f23

File tree

4 files changed

+49
-23
lines changed

4 files changed

+49
-23
lines changed

shortfin/python/shortfin_apps/llm/components/token_selection_strategy/beam_group.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,15 @@ class Beam(ABC):
4343
accumulated_normalization: float = 0.0
4444
last_token: int | None = None
4545

46-
def apply_temperature(self):
46+
def apply_temperature(self, logits: sfnp.device_array):
4747
"""Apply temperature to the logits of a decode invocation.
4848
4949
Args:
5050
temperature (float): Value to use for `temperature`.
5151
"""
5252
if self.decode_config.temperature == 1.0:
53-
return
54-
self.exec_req.result_logits = sfnp.divide(
55-
self.exec_req.result_logits, self.decode_config.temperature
56-
)
53+
return logits
54+
return sfnp.divide(logits, self.decode_config.temperature)
5755

5856
def convert_logits_normalization(
5957
self,
@@ -114,6 +112,10 @@ def _to_softmax(
114112
device,
115113
dtype,
116114
)
115+
116+
if logits_normalization == LogitsNormalization.NONE:
117+
probs_sf = self.apply_temperature(probs_sf)
118+
117119
probs = self.convert_logits_normalization(
118120
logits_normalization,
119121
LogitsNormalization.SOFTMAX,

shortfin/python/shortfin_apps/llm/components/token_selection_strategy/beam_search_token_selection_strategy.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@
3232

3333

3434
class BeamSearchBeam(Beam):
35-
def _convert_results_to_log_probs(self, probs: List):
35+
def _convert_results_to_log_probs(
36+
self,
37+
probs: List,
38+
):
3639
device = self.exec_req.result_logits.device
3740
dtype = self.exec_req.result_logits.dtype
3841
probs_sf = convert_list_to_device_array(
@@ -69,22 +72,42 @@ def sample_logits(self, k: int):
6972
Returns:
7073
Tuple[List[int], List[float]]: Tuple containing (top_tokens, top_values)
7174
"""
72-
self.apply_temperature()
75+
logits = self.exec_req.result_logits
7376
decode_config = self.decode_config
7477
num_beams = decode_config.num_beams
7578
top_k = decode_config.top_k
7679
top_p = decode_config.top_p
7780

7881
if (top_k, top_p) == (None, None):
79-
log_softmax_logits = self.convert_logits_normalization(
82+
tokens, probs = self.sampler.select_top_k(logits, -k)
83+
84+
# TODO: https://github.com/nod-ai/shark-ai/issues/1278 find cleaner way to do these conversions
85+
if logits.dtype in [sfnp.float16]:
86+
probs = [convert_float_to_int(prob, logits.dtype) for prob in probs]
87+
88+
probs_sf = convert_list_to_device_array(
89+
probs,
90+
[len(probs)],
91+
logits.device,
92+
logits.dtype,
93+
)
94+
95+
if self.decode_config.logits_normalization == LogitsNormalization.NONE:
96+
probs_sf = self.apply_temperature(probs_sf)
97+
98+
log_probs = self.convert_logits_normalization(
8099
self.decode_config.logits_normalization,
81100
LogitsNormalization.LOG_SOFTMAX,
82-
self.exec_req.result_logits,
83-
)
101+
probs_sf,
102+
).items.tolist()
84103

85-
return self.sampler.select_top_k(log_softmax_logits, -k)
104+
if logits.dtype in [sfnp.float16]:
105+
log_probs = [
106+
convert_int_to_float(log_prob, logits.dtype)
107+
for log_prob in log_probs
108+
]
86109

87-
logits = self.exec_req.result_logits
110+
return tokens, log_probs
88111

89112
if top_k is not None:
90113
# Sample from `top_k` tokens
@@ -110,7 +133,9 @@ def sample_logits(self, k: int):
110133
if logits.dtype in [sfnp.float16]:
111134
probs = [convert_float_to_int(prob, logits.dtype) for prob in probs]
112135

113-
log_probs = self._convert_results_to_log_probs(probs)
136+
log_probs = self._convert_results_to_log_probs(
137+
probs,
138+
)
114139

115140
return tokens, log_probs
116141

shortfin/python/shortfin_apps/llm/components/token_selection_strategy/greedy_token_selection_strategy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def sample_logits(self) -> int:
2626
Returns:
2727
int: The `argmax` of the logits.
2828
"""
29-
self.apply_temperature()
3029
exec_req = self.exec_req
3130
decode_config = self.decode_config
3231
top_k = decode_config.top_k

shortfin/tests/apps/llm/components/token_selection_strategy/beam_group_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,25 @@ def test_beam_apply_temperature(device, exec_req, decode_config):
7575

7676
with patch.object(sfnp, "divide") as temp_mock:
7777
expected = value / temperature
78-
beam.apply_temperature()
79-
logits = beam.exec_req.result_logits.items.tolist()
80-
assert all(approximately_equal(expected, logit) for logit in logits)
78+
logits = beam.exec_req.result_logits
79+
result = beam.apply_temperature(logits).items.tolist()
80+
assert all(approximately_equal(expected, logit) for logit in result)
8181
temp_mock.assert_not_called()
8282

8383
temperature = 0.5
8484
beam.decode_config.temperature = temperature
8585
expected = value / temperature
86-
beam.apply_temperature()
87-
logits = beam.exec_req.result_logits.items.tolist()
88-
assert all(approximately_equal(expected, logit) for logit in logits)
86+
logits = beam.exec_req.result_logits
87+
result = beam.apply_temperature(logits).items.tolist()
88+
assert all(approximately_equal(expected, logit) for logit in result)
8989

9090
temperature = 1.5
9191
beam.exec_req.result_logits.items = data
9292
beam.decode_config.temperature = temperature
9393
expected = value / temperature
94-
beam.apply_temperature()
95-
logits = beam.exec_req.result_logits.items.tolist()
96-
assert all(approximately_equal(expected, logit) for logit in logits)
94+
logits = beam.exec_req.result_logits
95+
result = beam.apply_temperature(logits).items.tolist()
96+
assert all(approximately_equal(expected, logit) for logit in result)
9797

9898

9999
def test_convert_logits_normalization_none(device, exec_req, decode_config):

0 commit comments

Comments (0)