
Commit d2fcbbc

artek0chumak and mryab authored
Add Mixtral models (#553)
* Add somehow workable version
* Fix generation
* Fixes
* Choose right attn
* style
* fix bloom
* remove unnes
* Update src/petals/models/mixtral/model.py
  Co-authored-by: Max Ryabinin <[email protected]>
* fix order of init

---------

Co-authored-by: Max Ryabinin <[email protected]>
1 parent 2ad0b2b commit d2fcbbc
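For context, a minimal client-side sketch of how the new Mixtral classes are meant to be used, mirroring the existing Petals quickstart. The checkpoint name is an assumption, and this only works if a swarm is actually serving that model:

# Hypothetical usage sketch (not part of this commit); assumes a swarm serving the checkpoint below.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)  # should resolve to DistributedMixtralForCausalLM

inputs = tokenizer("Mixture-of-experts models can now run on Petals because", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))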

File tree: 7 files changed (+344, -2 lines)

src/petals/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from petals.models.bloom import *
 from petals.models.falcon import *
 from petals.models.llama import *
+from petals.models.mixtral import *
src/petals/models/mixtral/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+from petals.models.mixtral.block import WrappedMixtralBlock
+from petals.models.mixtral.config import DistributedMixtralConfig
+from petals.models.mixtral.model import (
+    DistributedMixtralForCausalLM,
+    DistributedMixtralForSequenceClassification,
+    DistributedMixtralModel,
+)
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedMixtralConfig,
+    model=DistributedMixtralModel,
+    model_for_causal_lm=DistributedMixtralForCausalLM,
+    model_for_sequence_classification=DistributedMixtralForSequenceClassification,
+)
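With the classes registered above, the Petals auto classes can resolve a Mixtral checkpoint to its distributed counterparts by model type. A small sketch, assuming Hub access and the repo name below:

# Hypothetical sketch: after register_model_classes() runs on import, the auto config
# maps model_type "mixtral" to the distributed classes registered above.
from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")  # assumed repo name
print(type(config).__name__)  # expected: DistributedMixtralConfig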

src/petals/models/mixtral/block.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+from typing import Optional, Tuple
+
+import torch
+from transformers import MixtralConfig
+from transformers.cache_utils import DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+
+
+class WrappedMixtralBlock(MixtralDecoderLayer):
+    def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+
+        self._attn_implementation = config._attn_implementation
+        self.sliding_window = config.sliding_window
+        self.layer_idx = layer_idx
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs
+    ):
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
+            past_key_value = DynamicCache()
+            for idx in range(self.layer_idx):
+                past_key_value.update(
+                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
+                )
+            past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
+
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa":
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+                sliding_window=self.sliding_window,
+            )
+
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
+            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        # TODO: Move to mixin
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        # TODO: Move to mixin
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
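The two _reorder_cache_* helpers convert between the flat Bloom-style cache layout that Petals servers store, (batch * num_kv_heads, head_dim, seq_len) for keys and (batch * num_kv_heads, seq_len, head_dim) for values, and the per-head layout expected by Hugging Face Mixtral attention. A standalone sketch of that round trip with dummy tensors; the dimensions are assumptions for illustration (Mixtral-8x7B uses 8 KV heads with head_dim 128):

# Standalone illustration of the cache layout round trip performed above (dimensions are assumed)
import torch

batch_size, num_kv_heads, head_dim, past_len = 2, 8, 128, 5

# Bloom-style server layout
bloom_keys = torch.randn(batch_size * num_kv_heads, head_dim, past_len)
bloom_values = torch.randn(batch_size * num_kv_heads, past_len, head_dim)

# "from bloom": bring both tensors to the HF layout (batch, num_kv_heads, seq_len, head_dim)
hf_keys = bloom_keys.permute(0, 2, 1).reshape(batch_size, num_kv_heads, past_len, head_dim)
hf_values = bloom_values.view(batch_size, num_kv_heads, past_len, head_dim)

# "to bloom": collapse back to the flat server layout
back_keys = hf_keys.reshape(batch_size * num_kv_heads, past_len, head_dim).permute(0, 2, 1)
back_values = hf_values.reshape(batch_size * num_kv_heads, past_len, head_dim)

assert torch.equal(back_keys, bloom_keys) and torch.equal(back_values, bloom_values)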
src/petals/models/mixtral/config.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralAttention
+
+from petals.client.config import ClientConfig
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.models.mixtral.block import WrappedMixtralBlock
+
+logger = get_logger(__name__)
+
+
+class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedMixtralBlock
+    attn_class = MixtralAttention
+    block_prefix = "model.layers"
+
+    num_key_value_groups = 1
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            dht_prefix = dht_prefix.replace(".", "-")
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+        return result
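A quick illustration of what DistributedMixtralConfig.from_pretrained does to the DHT prefix when one is not supplied: the repo name is reused verbatim except that dots are replaced with dashes, presumably to keep the prefix a clean DHT key component. The repo name below is an assumption:

# Hypothetical sketch of the dht_prefix derivation above
model_name_or_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # assumed repo name
dht_prefix = str(model_name_or_path).replace(".", "-")
print(dht_prefix)  # mistralai/Mixtral-8x7B-Instruct-v0-1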

src/petals/models/mixtral/model.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from hivemind import DHT
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.models.mixtral import (
+    MixtralForCausalLM,
+    MixtralForSequenceClassification,
+    MixtralModel,
+    MixtralPreTrainedModel,
+)
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.mixtral.config import DistributedMixtralConfig
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
+    """MixtralModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
+
+    config_class = DistributedMixtralConfig
+
+    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.layers) == 0
+        config.num_hidden_layers = n_layer
+
+        self.layers = RemoteSequential(config, dht=dht)
+
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+        assert not output_router_logits, f"{output_router_logits=} is not supported"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
+
+        hidden_states = inputs_embeds
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
+        hidden_states = self.layers(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
+
+        # Remove prefix
+        if use_prompts:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
+        return self.embed_tokens
+
+    @property
+    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
+        return self.layers
+
+
+class DistributedMixtralForCausalLM(
+    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
+):
+    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedMixtralConfig
+
+    def __init__(self, config: DistributedMixtralConfig):
+        MixtralPreTrainedModel.__init__(self, config)
+        self.model = DistributedMixtralModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    @property
+    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
+        return self.model
+
+
+class DistributedMixtralForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
+):
+    def __init__(self, config: DistributedMixtralConfig):
+        MixtralPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.model = DistributedMixtralModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @property
+    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
+        return self.model
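The n_layer swap at the top of DistributedMixtralModel.__init__ is the usual Petals trick for skipping local layer construction: num_hidden_layers is temporarily set to 0 so the Hugging Face constructor builds an empty layer list, which is then replaced by a RemoteSequential. A minimal standalone sketch with a plain MixtralModel and tiny, assumed config sizes:

# Illustration only: shows the "prevent initialization" pattern outside of Petals.
from transformers import MixtralConfig, MixtralModel

config = MixtralConfig(
    hidden_size=32, intermediate_size=64, num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=4
)  # tiny assumed sizes
n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # prevent local layer construction
model = MixtralModel(config)  # builds embeddings and the final norm, but no decoder layers
assert len(model.layers) == 0
config.num_hidden_layers = n_layer  # restore for downstream consumers of the config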

src/petals/server/backend.py

Lines changed: 2 additions & 0 deletions
@@ -91,6 +91,8 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
             num_heads //= self.config.num_key_value_groups
+            if hasattr(self.config, "num_key_value_heads"):
+                num_heads = self.config.num_key_value_heads
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
             cache_tensors.extend((keys, values))
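Why the override matters: Mixtral uses grouped-query attention, so only num_key_value_heads key/value heads need to be cached per block rather than the full attention head count. A back-of-the-envelope sketch with assumed Mixtral-8x7B dimensions (32 attention heads, 8 KV heads, head_dim 128):

# Rough illustration of the per-block cache descriptors computed above (dimensions are assumptions)
batch_size, max_length = 1, 2048
num_key_value_heads, head_dim = 8, 128

num_heads = num_key_value_heads  # taken directly from the config when the attribute exists
keys_shape = (batch_size, num_heads, head_dim, max_length)
values_shape = (batch_size, num_heads, max_length, head_dim)
print(keys_shape, values_shape)  # (1, 8, 128, 2048) (1, 8, 2048, 128)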

src/petals/server/from_pretrained.py

Lines changed: 7 additions & 2 deletions
@@ -19,10 +19,11 @@
 from hivemind.utils.logging import get_logger
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo

 from petals.constants import DTYPE_MAP
+from petals.models.mixtral import WrappedMixtralBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
@@ -51,7 +52,11 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)

     with init_empty_weights():
-        block = config.block_class(config)
+        if config.block_class == WrappedMixtralBlock:
+            config = PreTrainedModel._autoset_attn_implementation(config)
+            block = config.block_class(config, block_index)
+        else:
+            block = config.block_class(config)

     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(
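A self-contained sketch of why the Mixtral branch above is needed: WrappedMixtralBlock reads config._attn_implementation in its constructor and takes a layer index, so the attention implementation has to be resolved before the block is instantiated on the meta device. The config sizes below are tiny assumptions for illustration, and this is not the exact server code path:

# Hypothetical sketch, assuming tiny config sizes; mirrors the dispatch added above
from accelerate import init_empty_weights
from transformers import MixtralConfig, PreTrainedModel
from petals.models.mixtral import WrappedMixtralBlock

config = MixtralConfig(
    hidden_size=32, intermediate_size=64, num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=4
)  # tiny assumed sizes
config = PreTrainedModel._autoset_attn_implementation(config)  # resolves config._attn_implementation

with init_empty_weights():
    block = WrappedMixtralBlock(config, 0)  # Mixtral blocks require their layer index
print(type(block).__name__, block.layer_idx)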
