
Commit 2f79208

maxdebayser authored and njhill committed
Fix the support for input_embeds in santacoder in sharded mode
* Remove default argument
* Activate multi-shard input_embeds test cases
* Change treatment of input_embeds in flash_santacoder: instead of dividing the prefix embeddings by the world size in all shards, only the rank 0 shard returns a non-zero tensor, preserving the semantics of the all_reduce operation without the potential loss of precision that a floating-point division entails.

1 parent 49a0b2c commit 2f79208
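
To make the reasoning in the commit message concrete, here is a minimal, non-distributed sketch: the all_reduce across shards is simulated as a plain sum over per-shard tensors, and the names (world_size, prefix) are illustrative rather than taken from the repository.

import torch

torch.manual_seed(0)
world_size = 4
# Stand-in for a tuned prompt prefix embedding table (float16, like the model weights).
prefix = torch.randn(8, 128, dtype=torch.float16)

# Previous approach: every shard contributes prefix / world_size and the
# all_reduce (here: a plain sum) adds the pieces back together. The float16
# division and summation can introduce rounding error.
reduced_divided = sum(prefix / world_size for _ in range(world_size))

# New approach: only rank 0 contributes the prefix, the other ranks contribute
# zeros, so the reduction reproduces the prefix exactly.
reduced_zeroed = sum(prefix if rank == 0 else torch.zeros_like(prefix)
                     for rank in range(world_size))

print(torch.equal(reduced_zeroed, prefix))            # True: adding zeros is exact
print((reduced_divided - prefix).abs().max().item())  # typically a small, non-zero error

When the world size is not a power of two, the division itself already rounds, which is exactly the precision loss the zero-contribution scheme avoids.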

File tree

8 files changed: +31, -32 lines


integration_tests/test_cases_bloom560m.yaml

Lines changed: 0 additions & 10 deletions
@@ -42,8 +42,6 @@
 
   # Prompt prefix
   - name: Greedy with tuned prompt prefix
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: bloom_sentiment_1
       params:
@@ -59,8 +57,6 @@
       text: ' positive'
 
   - name: Greedy with tuned prompt prefix and truncation
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: bloom_sentiment_1
       params:
@@ -79,8 +75,6 @@
 
   # Prompt prefix with nested path
   - name: Greedy with tuned prompt prefix with nested path (id)
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: nested/path
       params:
@@ -98,8 +92,6 @@
 
   # Prompt prefix returning input and generated tokens
   - name: Greedy with tuned prompt prefix and returned tokens
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: bloom_sentiment_1
       params:
@@ -279,8 +271,6 @@
 
   # Error case - invalid prefix id
   - name: Error case - invalid prefix id
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: invalid_prefix_id
       params:

integration_tests/test_cases_tinyllama.yaml

Lines changed: 0 additions & 6 deletions
@@ -160,8 +160,6 @@
 
   # Prompt prefix
   - name: Greedy with tuned prompt prefix
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: tinyllama
       params:
@@ -179,8 +177,6 @@
 
   # Prompt prefix with truncation
   - name: Greedy with tuned prompt prefix with truncation
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: tinyllama
       params:
@@ -201,8 +197,6 @@
 
   # Prompt prefix returning input and generated tokens
   - name: Greedy with tuned prompt prefix and returned tokens
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       prefixId: tinyllama
       params:

integration_tests/test_cases_tinystarcoderpy.yaml

Lines changed: 0 additions & 6 deletions
@@ -172,8 +172,6 @@
 
   # Prompt prefix
   - name: Greedy with tuned prompt prefix
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       # Prefix is "def hello_world():\n"
       prefixId: tiny_starcoder
@@ -189,8 +187,6 @@
       text: "(\"Hello World!\")\n\nhello_world()\n"
 
   - name: Greedy with tuned prompt prefix and truncation
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       # Prefix is "def hello_world():\n"
       prefixId: tiny_starcoder
@@ -209,8 +205,6 @@
 
   # Prompt prefix returning input and generated tokens
   - name: Greedy with tuned prompt prefix and returned tokens
-    # Prompt prefixes with multi-shard not yet supported
-    singleShardOnly: true
     request:
       # Prefix is "def hello_world():\n"
       prefixId: tiny_starcoder

integration_tests/text_generation_tests/test_server.py

Lines changed: 2 additions & 2 deletions
@@ -396,7 +396,7 @@ async def test_mt0(server_fixture, test_cases):
 # test with tiny GPTBigCode model for the merged kv cache
 @pytest.mark.model("bigcode/tiny_starcoder_py")
 @pytest.mark.extensions(".safetensors,.json")
-@pytest.mark.shards(1)
+@pytest.mark.shards(2)
 @pytest.mark.test_case_file("test_cases_tinystarcoderpy.yaml")
 @pytest.mark.asyncio
 async def test_gptbigcode(server_fixture, test_cases):
@@ -405,7 +405,7 @@ async def test_gptbigcode(server_fixture, test_cases):
 # test with Llama model which has tokenizer.add_bos_token == true
 @pytest.mark.model("Maykeye/TinyLLama-v0")
 @pytest.mark.extensions(".bin,.json,.model")
-@pytest.mark.shards(1)
+@pytest.mark.shards(2)
 @pytest.mark.test_case_file("test_cases_tinyllama.yaml")
 @pytest.mark.asyncio
 async def test_llama(server_fixture, test_cases):

server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

Lines changed: 1 addition & 2 deletions
@@ -409,10 +409,9 @@ def forward(
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time"
             )
-
+
         if inputs_embeds is not None:
             hidden_states = inputs_embeds + self.wpe(position_ids)
-            # TODO: support TP for the position embeddings
         else:
             hidden_states = self.wte(input_ids) + self.wpe(position_ids)
 

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 0 additions & 1 deletion
@@ -156,7 +156,6 @@ def from_pb(
 
         # convert all requests to embeddings if any request has a prefix_id
         if prefix_ids:
-            # TODO: Handle TP distributed embeddings layer
            inputs_embeds = embeddings_lookup(input_ids)
            input_ids = None
            # fill in the prefix embeddings into the space that we already
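
For orientation, a hypothetical sketch of the step described by the comment at the end of the hunk above (the comment continues beyond the visible context): writing cached prefix embeddings into space already reserved at the front of a request's embedding slice. The helper name fill_prefix, the tensor layout and the start offset are assumptions for illustration, not the repository's actual code.

import torch

def fill_prefix(inputs_embeds: torch.Tensor, prefix: torch.Tensor, start: int) -> None:
    """Hypothetical helper: write cached prefix embeddings into space that was
    already reserved at the front of one request's embedding slice.

    inputs_embeds: (total_len, hidden) embeddings for one request, with the
                   first prefix.shape[0] rows reserved for the prefix
    prefix:        (prefix_len, hidden) cached prefix; on non-rank-0 shards this
                   is an all-zero tensor when return_zero is active
    start:         offset of the request inside a flattened batch
    """
    prefix_len = prefix.shape[0]
    inputs_embeds[start:start + prefix_len] = prefix

On non-rank-0 shards the prefix tensor handed in here is the all-zero view produced by PrefixCache, so the later all_reduce adds the real prefix exactly once.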

server/text_generation_server/models/model.py

Lines changed: 10 additions & 0 deletions
@@ -53,11 +53,21 @@ def __init__(self, engine: BaseInferenceEngine, dtype: torch.dtype):
         decoder_start_token_id = self.model.config.decoder_start_token_id
         if decoder_start_token_id is None:
             decoder_start_token_id = self.tokenizer.bos_token_id
+
+        return_zero = False
+        # If the word_embeddings layer is configured not to reduce at the end of the forward() call
+        # each shard will have only a partial tensor. This tensor cannot be concatenated with a
+        # prefix tensor in each shard because the reduce that happens afterwards would result
+        # in adding the prefix N times, where N is the world size.
+        if isinstance(self.word_embeddings, TensorParallelEmbedding) and not self.word_embeddings.reduce:
+            return_zero = self.word_embeddings.process_group.rank() != 0
+
         self.prefix_cache = PrefixCache(
             device=self.device,
             dtype=dtype,
             max_length=MAX_PROMPT_PREFIX_LENGTH,
             encoder_decoder=self.model.config.is_encoder_decoder,
+            return_zero=return_zero,
             decoder_start_tok_embedding=self.word_embeddings(
                 torch.tensor([decoder_start_token_id], device=self.device, dtype=torch.long)
             ) if decoder_start_token_id is not None else None,
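
The comment added in this hunk is the heart of the fix: with a sharded, non-reducing embedding layer, each rank produces only a partial embedding, and anything added identically on every rank ends up multiplied by the world size after the all_reduce. A rough sketch of that failure mode, using a toy vocab-sharded lookup in place of TensorParallelEmbedding and a plain sum in place of the all_reduce (everything below is illustrative):

import torch

world_size = 2
hidden = 4
vocab = 6
torch.manual_seed(0)
full_table = torch.randn(vocab, hidden)
# Toy vocab sharding: each rank owns a slice of the embedding table and
# returns zeros for ids it does not own (reduce=False behaviour).
shards = full_table.chunk(world_size, dim=0)

def partial_lookup(rank: int, ids: torch.Tensor) -> torch.Tensor:
    lo = rank * (vocab // world_size)
    hi = lo + vocab // world_size
    out = torch.zeros(len(ids), hidden)
    mask = (ids >= lo) & (ids < hi)
    out[mask] = shards[rank][ids[mask] - lo]
    return out

ids = torch.tensor([0, 3, 5])
prefix_row = torch.randn(1, hidden)

# Wrong: every rank prepends the prefix to its partial result; the "all_reduce"
# (sum over ranks) then contains the prefix world_size times.
wrong = sum(torch.cat((prefix_row, partial_lookup(r, ids))) for r in range(world_size))

# Fix from this commit: only rank 0 contributes the prefix row.
right = sum(torch.cat((prefix_row if r == 0 else torch.zeros_like(prefix_row),
                       partial_lookup(r, ids)))
            for r in range(world_size))

print(torch.allclose(wrong[0], world_size * prefix_row))  # True: prefix counted world_size times
print(torch.allclose(right[0], prefix_row))               # True: prefix counted once
print(torch.allclose(right[1:], full_table[ids]))         # token embeddings reduce correctly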

server/text_generation_server/prompt_cache.py

Lines changed: 18 additions & 5 deletions
@@ -4,7 +4,7 @@
 from pathlib import Path
 import re
 import threading
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Optional
 
 import torch
 
@@ -149,6 +149,7 @@ def __init__(
         dtype: torch.dtype,
         max_length: int,
         encoder_decoder: bool,
+        return_zero: Optional[bool],
         decoder_start_tok_embedding: torch.Tensor,
     ):
         self.max_length = max_length
@@ -158,6 +159,7 @@ def __init__(
         self.dtype = dtype
 
         self.is_encoder_decoder = encoder_decoder
+        self.zero = torch.zeros((1,), dtype=dtype, device=device) if return_zero else None
         self.decoder_start_tok_embedding = decoder_start_tok_embedding
 
         self.cache_map: Dict[str, PromptCacheNode] = {}
@@ -210,23 +212,34 @@ def _load_embedding_tensors(self, prefix_id: str) -> Union[torch.Tensor, Tuple[t
         decoder_prefix = self._load_embedding_tensor(prefix_id, "decoder.pt", dtype=self.dtype)
         # For encoder-decoder we store a tuple of (encoder_prefix, decoder_prefix),
         # at least one must be non-None
+        if decoder_prefix is not None:
+            if self.zero is not None:
+                decoder_prefix = self.zero.expand(decoder_prefix.shape)
+            else:
+                decoder_prefix = decoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
+
         if self.is_encoder_decoder:
             encoder_prefix = self._load_embedding_tensor(prefix_id, "encoder.pt", dtype=self.dtype)
             if decoder_prefix is None:
                 if encoder_prefix is None:
                     raise PrefixNotFound(f"Prefix id {prefix_id} not found")
             else:
-                decoder_prefix = decoder_prefix.to(self.device, non_blocking=True)
                 # TODO confirm this cat is correct
-                decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
+                if self.zero is not None:
+                    decoder_prefix = self.zero.expand(decoder_prefix.shape[0] + 1, *decoder_prefix.shape[1:])
+                else:
+                    decoder_prefix = torch.cat((decoder_prefix, self.decoder_start_tok_embedding))
             if encoder_prefix is not None:
-                encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
+                if self.zero is not None:
+                    encoder_prefix = self.zero.expand(encoder_prefix.shape)
+                else:
+                    encoder_prefix = encoder_prefix.to(self.device, non_blocking=True)
             prefix = encoder_prefix, decoder_prefix
         # For decoder-only we store just the decoder prefix
         elif decoder_prefix is None:
             raise PrefixNotFound(f"Prefix id {prefix_id} not found")
         else:
-            prefix = decoder_prefix.to(self.dtype).to(self.device, non_blocking=True)
+            prefix = decoder_prefix
         return prefix
 
     def _load_embedding_tensor(self, prefix_id: str, filename: str, dtype: torch.dtype) -> torch.Tensor:
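
One note on the zero.expand(...) calls introduced above: expand broadcasts the single cached zero element into a view of the requested shape without allocating a full tensor, and the shape[0] + 1 form matches the extra row that rank 0 gains by concatenating decoder_start_tok_embedding. A small standalone sketch of the shape arithmetic (the sizes are made up for illustration):

import torch

device, dtype = "cpu", torch.float16
zero = torch.zeros((1,), dtype=dtype, device=device)

decoder_prefix = torch.randn(16, 64, dtype=dtype)              # (prefix_len, hidden)
decoder_start_tok_embedding = torch.randn(1, 64, dtype=dtype)  # (1, hidden)

# Rank 0 (zero is None in the real class): prefix plus the decoder start token.
with_start = torch.cat((decoder_prefix, decoder_start_tok_embedding))
print(with_start.shape)     # torch.Size([17, 64])

# Other ranks: an all-zero view with the same final shape, no large allocation.
zeros_view = zero.expand(decoder_prefix.shape[0] + 1, *decoder_prefix.shape[1:])
print(zeros_view.shape)     # torch.Size([17, 64])
print(zeros_view.stride())  # (0, 0) - a broadcast view over a single element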

0 commit comments
