Commit d6f4f80

artek0chumak and mryab authored
Fix Mixtral-related issues (#570)
This PR fixes problems related to #569:
- block initialization
- throughput calculation and cache usage
- Mixtral in tests

Beam search is removed for Mixtral and Llama for now: those models use DynamicCache, which requires a special function to reorder it (see https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L161).

Co-authored-by: Max Ryabinin <[email protected]>
1 parent d2fcbbc commit d6f4f80
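
For context, here is a rough sketch (not taken from this commit) of why beam search needs cache-specific support: after every generation step, the per-layer key/value cache must be reordered along the batch dimension to match the surviving beams.

    import torch

    # With the legacy tuple-of-(key, value) cache, beam reordering is a plain index_select:
    def reorder_legacy_cache(past_key_values, beam_idx: torch.LongTensor):
        return tuple(
            (key.index_select(0, beam_idx), value.index_select(0, beam_idx))
            for key, value in past_key_values
        )

    # DynamicCache (used by Mixtral and Llama in recent transformers) stores keys and values
    # in its own .key_cache / .value_cache lists, so the reordering has to go through the
    # cache object instead (see the cache_utils.py link above). Until Petals supports that,
    # reorder_cache() now raises NotImplementedError for these models.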

File tree

13 files changed (+79, -34 lines)


.github/workflows/run-tests.yaml

Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,8 @@ jobs:
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
+          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 20

src/petals/client/remote_generation.py

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ def update_seen(self, new_seen: int) -> None:
         self.seen_tokens += new_seen

     def reorder_cache(self, beam_idx):
-        pass
+        raise NotImplementedError("Beam search reordering is not implemented yet")


 _skipped_tokens = ContextVar("skipped_tokens", default=0)

src/petals/models/bloom/block.py

Lines changed: 6 additions & 0 deletions

@@ -9,6 +9,8 @@
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor

+from petals.utils.misc import is_dummy
+

 class WrappedBloomBlock(BloomBlock):
     def forward(
@@ -22,6 +24,10 @@ def forward(
     ):
         assert attention_mask is None, "Non-causal attention masks are not supported yet"
         batch_size, seq_length = hidden_states.shape[:2]
+        if layer_past is not None and is_dummy(layer_past[0]):
+            # Bloom cannot use the cache if it was misconstructed (e.g., filled with dummy tensors).
+            # In this case, fall back to the old code path:
+            layer_past = None
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         seq_length_with_past = seq_length + past_length
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)

src/petals/models/mixtral/block.py

Lines changed: 6 additions & 6 deletions

@@ -1,3 +1,4 @@
+import json
 from typing import Optional, Tuple

 import torch
@@ -33,16 +34,15 @@ def forward(
         past_key_values_length = 0

         past_key_value = layer_past
+
         if past_key_value is not None:
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
             past_key_value = DynamicCache()
-            for idx in range(self.layer_idx):
-                past_key_value.update(
-                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
-                )
-            past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
+            past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
+            past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
+            past_key_value._seen_tokens = past_key_values_length

         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
@@ -83,7 +83,7 @@ def forward(

         if use_cache:
             present_key_value = outputs[-1]
-            present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
+            present_key_value = present_key_value[self.layer_idx]
             present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
             outputs = outputs[:-1] + (present_key_value,)

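The block above slots a single layer's past into a DynamicCache by hand. A minimal sketch of the same trick, assuming transformers' DynamicCache with its key_cache/value_cache lists (the helper name is hypothetical, not part of this commit): a Petals server holds only one real layer, so earlier slots are padded with empty tensors and the real (key, value) pair is placed at layer_idx.

    import torch
    from transformers.cache_utils import DynamicCache

    def wrap_single_layer_past(key, value, layer_idx, seen_tokens):
        # Hypothetical helper mirroring the diff above: pad the cache lists up to
        # layer_idx with empty tensors and put the real (key, value) pair last.
        cache = DynamicCache()
        cache.key_cache = [torch.empty(0) for _ in range(layer_idx)] + [key]
        cache.value_cache = [torch.empty(0) for _ in range(layer_idx)] + [value]
        cache._seen_tokens = seen_tokens  # private counter, set the same way as in the diff
        return cache

After the forward pass, indexing the returned cache with the layer index (as in `present_key_value[self.layer_idx]` above) yields this layer's (key, value) pair again.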
src/petals/models/mixtral/model.py

Lines changed: 15 additions & 6 deletions

@@ -122,14 +122,20 @@ def forward(
     def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
         return self.embed_tokens

+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
+        return nn.Identity()
+
     @property
     def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
         return self.layers

+    @property
+    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
+        return self.norm

-class DistributedMixtralForCausalLM(
-    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
-):
+
+class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

@@ -151,9 +157,12 @@ def transformer(self) -> DistributedMixtralModel:  # For compatibility with Remo
         return self.model


-class DistributedMixtralForSequenceClassification(
-    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
-):
+class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
+    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedMixtralConfig
+
     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels

src/petals/server/block_utils.py

Lines changed: 15 additions & 2 deletions

@@ -2,8 +2,9 @@

 import torch
 from accelerate import init_empty_weights
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel

+from petals.models.mixtral.block import WrappedMixtralBlock
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import get_size_in_bytes

@@ -32,7 +33,7 @@ def get_block_size(
     ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'

     with init_empty_weights(include_buffers=True):
-        block = config.block_class(config)
+        block = get_model_block(config)
     n_params = sum(param.numel() for param in block.parameters())

     if location == "memory":
@@ -50,3 +51,15 @@
     bytes_per_value = get_size_in_bytes(dtype)

     return round(n_params * bytes_per_value * (1 + eps))
+
+
+def get_model_block(config, layer_idx: int = 0):
+    """
+    The function to create a model block based on the block class
+    kwargs argument **only** is necessary for specific classes, like Mixtral.
+    They will not be passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        config = PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, layer_idx)
+    return config.block_class(config)

src/petals/server/from_pretrained.py

Lines changed: 2 additions & 6 deletions

@@ -24,7 +24,7 @@

 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
@@ -52,11 +52,7 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)

     with init_empty_weights():
-        if config.block_class == WrappedMixtralBlock:
-            config = PreTrainedModel._autoset_attn_implementation(config)
-            block = config.block_class(config, block_index)
-        else:
-            block = config.block_class(config)
+        block = get_model_block(config, layer_idx=block_index)

     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(

src/petals/server/throughput.py

Lines changed: 13 additions & 5 deletions

@@ -13,9 +13,10 @@
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig

-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST

 logger = get_logger(__name__)

@@ -201,18 +202,25 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

-        cache = None
+        cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        # Skip the 1st step to exclude the initialization time
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
         synchronize(device)

         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
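
As a quick sanity check of the throughput formula above, with made-up numbers:

    # Illustrative numbers only: 10 timed steps over 16 tokens each, measured at 0.25 s total
    n_steps, n_tokens, elapsed = 10, 16, 0.25
    device_rps = n_steps * n_tokens / elapsed
    assert device_rps == 640.0  # the block processes 640 tokens per second in this example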

src/petals/utils/misc.py

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@

 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)

+DUMMY_KEY_PAST = torch.empty((0, 0, 0))
+

 def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
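
DUMMY_KEY_PAST is an empty (0, 0, 0) tensor, so is_dummy() treats it as a placeholder. That is what lets measure_compute_rps pass it as a fake layer_past and lets WrappedBloomBlock fall back to the no-cache path. A quick check using the definitions above:

    import torch

    DUMMY_KEY_PAST = torch.empty((0, 0, 0))

    def is_dummy(tensor: torch.Tensor) -> bool:
        return tensor.numel() == 0

    layer_past = (DUMMY_KEY_PAST, DUMMY_KEY_PAST)
    assert is_dummy(layer_past[0])  # WrappedBloomBlock sees this and resets layer_past to None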

src/petals/utils/peft.py

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo

-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import get_size_in_bytes
@@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
-        block = block_config.block_class(block_config)
+        block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
         create_lora_adapter(block, quant_type=QuantType.NONE)
