Commit daa1d81

feat(server): Support Galactica (IBM#4)
Parent: d6d5b12

7 files changed: +383 -30 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [BLOOM](https://huggingface.co/bigscience/bloom)
 - [BLOOMZ](https://huggingface.co/bigscience/bloomz)
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
+- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 
 Other models are supported on a best effort basis using:
 

server/Makefile

Lines changed: 5 additions & 5 deletions
@@ -9,11 +9,11 @@ gen-server:
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
 	rm transformers || true
-	rm transformers-b55f16c5b71aeef47a66a4270e19c154f050a7a7 || true
-	curl -L -O https://github.com/OlivierDehaene/transformers/archive/b55f16c5b71aeef47a66a4270e19c154f050a7a7.zip
-	unzip b55f16c5b71aeef47a66a4270e19c154f050a7a7.zip
-	rm b55f16c5b71aeef47a66a4270e19c154f050a7a7.zip
-	mv transformers-b55f16c5b71aeef47a66a4270e19c154f050a7a7 transformers
+	rm transformers-text_generation_inference || true
+	curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
+	unzip text_generation_inference.zip
+	rm text_generation_inference.zip
+	mv transformers-text_generation_inference transformers
 	cd transformers && python setup.py install
 
 install-torch:
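Pinning the fork by branch name instead of a commit SHA keeps the archive URL stable as the text_generation_inference branch moves. As an illustration only (not part of the commit), the same fetch-and-unpack step could be scripted in Python; only the URL comes from the Makefile, everything else is an assumption:

# Illustrative Python equivalent of the install-transformers target above.
import shutil
import urllib.request
import zipfile

ARCHIVE_URL = (
    "https://github.com/OlivierDehaene/transformers/archive/refs/heads/"
    "text_generation_inference.zip"
)

urllib.request.urlretrieve(ARCHIVE_URL, "text_generation_inference.zip")
with zipfile.ZipFile("text_generation_inference.zip") as archive:
    archive.extractall(".")

# GitHub branch archives unpack as <repo>-<branch>; rename to match the Makefile.
shutil.rmtree("transformers", ignore_errors=True)
shutil.move("transformers-text_generation_inference", "transformers")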

server/text_generation/models/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,7 @@
 from text_generation.models.causal_lm import CausalLM
 from text_generation.models.bloom import BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
+from text_generation.models.galactica import Galactica, GalacticaSharded
 
 __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
 

@@ -12,6 +13,11 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
             return BLOOMSharded(model_name, quantize=quantize)
         else:
             return CausalLM(model_name, quantize=quantize)
+    elif model_name.startswith("facebook/galactica"):
+        if sharded:
+            return GalacticaSharded(model_name, quantize=quantize)
+        else:
+            return Galactica(model_name, quantize=quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")

server/text_generation/models/bloom.py

Lines changed: 2 additions & 0 deletions
@@ -63,6 +63,8 @@ def __init__(self, model_name: str, quantize: bool = False):
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_name, extension=".safetensors")
+        if not filenames:
+            raise ValueError("No safetensors weights found")
 
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config)
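The added guard fails fast when no .safetensors shards are found, instead of erroring later during weight loading. A standalone sketch of the same pattern; local_weight_files and the directory glob are illustrative stand-ins, not the repository's weight_files helper:

from pathlib import Path
from typing import List

def local_weight_files(weights_dir: str, extension: str = ".safetensors") -> List[Path]:
    # Illustrative stand-in: list shard files under a local directory.
    filenames = sorted(Path(weights_dir).glob(f"*{extension}"))
    if not filenames:
        # Same fail-fast behaviour as the check added to BLOOMSharded.__init__.
        raise ValueError("No safetensors weights found")
    return filenames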

server/text_generation/models/causal_lm.py

Lines changed: 22 additions & 24 deletions
@@ -156,31 +156,29 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                 past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                 past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])
 
-                _, num_heads, head_dim, padded_sequence_length = past_keys.shape
+                _, num_heads, padded_sequence_length, head_dim = past_values.shape
 
-                padded_past_keys_shape = (
+                padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    head_dim,
                     max_sequence_length - 1,
+                    head_dim,
                 )
 
-                # head_dim is last for BLOOM
-                if past_values.shape[-1] == head_dim:
-                    past_values_head_dim_last = True
-                    padded_past_values_shape = (
+                # seq_length is last for BLOOM
+                if past_keys.shape[-2] == head_dim:
+                    past_keys_head_dim_last = False
+                    padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
-                        max_sequence_length - 1,
                         head_dim,
+                        max_sequence_length - 1,
                     )
-                elif past_values.shape[-2] == head_dim:
-                    past_values_head_dim_last = False
-                    padded_past_values_shape = padded_past_keys_shape
+                elif past_keys.shape[-1] == head_dim:
+                    past_keys_head_dim_last = True
+                    padded_past_keys_shape = padded_past_values_shape
                 else:
-                    raise ValueError(
-                        f"past_values shape {past_values.shape} is not valid"
-                    )
+                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
 
                 # This will run only once per layer
                 if j == len(past_key_values):
@@ -197,24 +195,24 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                     past_key_values.append((padded_past_keys, padded_past_values))
 
                 # We slice the past keys and values to remove the padding from previous batches
-                past_key_values[j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-
-                if past_values_head_dim_last:
-                    past_key_values[j][1][
+                if past_keys_head_dim_last:
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
-                    past_key_values[j][1][
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
                         -(batch.max_sequence_length - 1) :,
-                    ] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+
+                past_key_values[j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
 
                 start_index += batch.size
 
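This rewrite of concatenate moves the layout special-casing from the values to the keys: past values arrive as [batch, num_heads, seq_length, head_dim] in both cases, while BLOOM stores its past keys with the last two dimensions swapped, so the branch now inspects the keys and the values get one unconditional right-aligned copy. A toy sketch of the two key layouts and the padded copy; all shapes here are invented for illustration:

import torch

num_heads, head_dim = 2, 4
batch_size, seq_len = 3, 5      # one incoming batch with 5 past positions
total_batch, max_seq = 8, 9     # concatenated batch and its max past length

# BLOOM-style keys: head_dim before seq_length; OPT/Galactica-style keys: seq_length first.
past_keys_bloom = torch.randn(batch_size, num_heads, head_dim, seq_len)
past_keys_opt = torch.randn(batch_size, num_heads, seq_len, head_dim)

for past_keys in (past_keys_bloom, past_keys_opt):
    head_dim_last = past_keys.shape[-1] == head_dim
    if head_dim_last:
        padded = torch.zeros(total_batch, num_heads, max_seq, head_dim)
        # Right-align this batch's past inside the padded sequence dimension.
        padded[:batch_size, :, -seq_len:, :] = past_keys
    else:
        padded = torch.zeros(total_batch, num_heads, head_dim, max_seq)
        padded[:batch_size, :, :, -seq_len:] = past_keys
    print(tuple(padded.shape), "head_dim_last =", head_dim_last)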

@@ -243,13 +241,13 @@ def __init__(self, model_name: str, quantize=False):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
+        tokenizer.pad_token_id = self.model.config.pad_token_id
 
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
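The constructor change stops adding a brand-new [PAD] token, whose id would sit outside the pretrained embedding matrix unless the model were resized, and instead reuses the pad token id already declared in the model config. A minimal before/after sketch; the checkpoint name is just a placeholder:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/galactica-125m"  # placeholder checkpoint for illustration

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

# Before: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) mints a new token id
# that the embedding table does not cover without resize_token_embeddings().
# After: reuse the pad id the pretrained config already defines.
tokenizer.pad_token_id = model.config.pad_token_id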
