
Commit 0f57e51

feat(transformers): add MiniMax (#1186)
* add MiniMax
* add unit tests
* add transformer version for model imports
* small fix
* add copyright
1 parent 48a81b1 commit 0f57e51

13 files changed: +1808 -16 lines

mindone/transformers/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -539,4 +539,12 @@
     Glm4vTextModel,
     Glm4vVisionModel,
 )
+from .models.minimax import (
+    MiniMaxForCausalLM,
+    MiniMaxForQuestionAnswering,
+    MiniMaxForSequenceClassification,
+    MiniMaxForTokenClassification,
+    MiniMaxModel,
+    MiniMaxPreTrainedModel,
+)
 from .models.vjepa2 import VJEPA2ForVideoClassification, VJEPA2Model, VJEPA2PreTrainedModel

mindone/transformers/generation/utils.py

Lines changed: 1 addition & 0 deletions
@@ -1830,6 +1830,7 @@ def _supports_default_dynamic_cache(self) -> bool:
             and "jamba" not in self.__class__.__name__.lower()
             and "zamba" not in self.__class__.__name__.lower()
             and "bamba" not in self.__class__.__name__.lower()
+            and "minimax" not in self.__class__.__name__.lower()
         )

     def _supports_default_dynamic_input(self) -> bool:
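For illustration (not part of the diff): the hunk extends a name-based opt-out list, so any model class whose lowercase name contains one of these substrings skips the default dynamic cache. A simplified stand-in of just that check, assuming nothing about the rest of the method:

# Simplified sketch of the substring check above; the real method also
# evaluates conditions that are outside this hunk.
_NON_DEFAULT_CACHE_FAMILIES = ("jamba", "zamba", "bamba", "minimax")

def supports_default_dynamic_cache(model_class_name: str) -> bool:
    lowered = model_class_name.lower()
    return not any(family in lowered for family in _NON_DEFAULT_CACHE_FAMILIES)

print(supports_default_dynamic_cache("LlamaForCausalLM"))    # True
print(supports_default_dynamic_cache("MiniMaxForCausalLM"))  # False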

mindone/transformers/masking_utils.py

Lines changed: 160 additions & 8 deletions
@@ -64,6 +64,25 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
     return kv_idx <= q_idx


+def sliding_window_overlay(sliding_window: int) -> Callable:
+    """
+    This is an overlay depicting a sliding window pattern. Add it on top of a causal mask for a proper sliding
+    window mask.
+    """
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        return kv_idx > q_idx - sliding_window
+
+    return inner_mask
+
+
+def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
+    """
+    This returns the mask_function used to create a sliding window mask.
+    """
+    return and_masks(sliding_window_overlay(sliding_window), causal_mask_function)
+
+
 def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
     """
     Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
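As a quick illustration (not part of the diff), the composed predicate admits key positions in the half-open window (q_idx - sliding_window, q_idx]. The plain-Python sketch below reimplements the two predicates for a toy 5-token sequence; it does not use the library's vmapped code path.

# Toy re-implementation of the predicates above (illustrative only).
def causal(q_idx: int, kv_idx: int) -> bool:
    return kv_idx <= q_idx

def sliding_overlay(window: int):
    return lambda q_idx, kv_idx: kv_idx > q_idx - window

def sliding_window_causal(window: int, q_idx: int, kv_idx: int) -> bool:
    return causal(q_idx, kv_idx) and sliding_overlay(window)(q_idx, kv_idx)

for q in range(5):
    print(" ".join("x" if sliding_window_causal(3, q, k) else "." for k in range(5)))
# x . . . .
# x x . . .
# x x x . .
# . x x x .
# . . x x x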
@@ -280,12 +299,65 @@ def eager_mask(
     return mask


+def flash_attention_mask(
+    batch_size: int,
+    cache_position: ms.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Callable = causal_mask_function,
+    attention_mask: Optional[ms.Tensor] = None,
+    **kwargs,
+):
+    """
+    Create the attention mask necessary to use FA2. Since FA2 is un-padded by definition, here we simply return
+    `None` if the mask is fully causal, or we return the 2D mask which will then be used to extract the seq_lens.
+    We just slice it in case of sliding window.
+
+    Args:
+        batch_size (`int`):
+            The batch size of the input sequence.
+        cache_position (`ms.Tensor`):
+            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+        kv_length (`int`):
+            The size that the key and value states will have during the attention computation.
+        kv_offset (`int`, optional):
+            An optional offset to indicate at which first position the key and values states will refer to.
+        mask_function (`Callable`):
+            The mask factory function describing the mask pattern.
+        attention_mask (`ms.Tensor`, optional):
+            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
+    """
+    if attention_mask is not None:
+        # Here we need to slice from the right if using sliding or chunked (for full attention, this is equivalent to doing nothing)
+        attention_mask = attention_mask[:, -kv_length:]
+        # We only return an actual mask if there is at least 1 padding token, otherwise we return `None` and use `is_causal` in FA2
+        # (note that the attention_mask is a boolean dtype here)
+        if attention_mask.all():
+            attention_mask = None
+
+    return attention_mask
+
+
+def flex_attention_mask(
+    batch_size: int,
+    cache_position: ms.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Callable = causal_mask_function,
+    attention_mask: Optional[ms.Tensor] = None,
+    **kwargs,
+):
+    raise NotImplementedError("`flex_attention` is not supported yet.")
+
+
 class AttentionMaskInterface(GeneralInterface):
     # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
     # a new instance is created (in order to locally override a given function)
     _global_mapping = {
+        "sdpa": sdpa_mask,
         "eager": eager_mask,
-        "flash_attention_2": eager_mask,
+        "flash_attention_2": flash_attention_mask,
+        "flex_attention": flex_attention_mask,
     }

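A minimal sketch (assumed shapes, numpy as a stand-in for ms.Tensor) of the behaviour `flash_attention_mask` documents: slice the boolean padding mask to the last kv_length columns and return None when nothing is padded, so FA2 can take its is_causal fast path.

# Illustrative only; numpy used in place of ms.Tensor.
import numpy as np

def toy_flash_attention_mask(attention_mask, kv_length):
    if attention_mask is not None:
        attention_mask = attention_mask[:, -kv_length:]   # slice from the right
        if attention_mask.all():                          # no padding left in view
            return None
    return attention_mask

no_padding = np.ones((2, 6), dtype=bool)
with_padding = no_padding.copy()
with_padding[1, :2] = False                               # two left-padded tokens in sequence 1

print(toy_flash_attention_mask(no_padding, kv_length=6))   # None
print(toy_flash_attention_mask(with_padding, kv_length=6)) # (2, 6) boolean mask
print(toy_flash_attention_mask(with_padding, kv_length=4)) # None: padding falls outside the kv window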
@@ -308,13 +380,13 @@ def _preprocess_mask_arguments(
     Args:
         config (`PretrainedConfig`):
             The model config.
-        input_embeds (`torch.Tensor`):
+        input_embeds (`ms.Tensor`):
             The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
             batch size, query length and dtype.
-        attention_mask (`torch.Tensor`, optional):
+        attention_mask (`ms.Tensor`, optional):
             The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
             It can also be an already prepared 4D mask, in which case it is returned as-is.
-        cache_position (`torch.Tensor`):
+        cache_position (`ms.Tensor`):
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
@@ -325,7 +397,7 @@ def _preprocess_mask_arguments(
     Returns:
         early_exit (`bool`):
             Whether we should early exit mask creation, and return the mask as-is.
-        attention_mask (`torch.Tensor` or `BlockMask` or `None`):
+        attention_mask (`ms.Tensor` or `BlockMask` or `None`):
             The attention mask to either return immediately, or to use in downstream mask creation.
         kv_length (`int`):
             The size that the key and value states will have during the attention computation.
@@ -375,13 +447,13 @@ def create_causal_mask(
     Args:
         config (`PretrainedConfig`):
            The model config.
-        input_embeds (`torch.Tensor`):
+        input_embeds (`ms.Tensor`):
             The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
             batch size, query length and dtype.
-        attention_mask (`torch.Tensor`, optional):
+        attention_mask (`ms.Tensor`, optional):
             The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
             It can also be an already prepared 4D mask, in which case it is returned as-is.
-        cache_position (`torch.Tensor`):
+        cache_position (`ms.Tensor`):
             A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
         past_key_values (`Cache`, optional):
             The past key values, if we use a cache.
@@ -435,6 +507,86 @@ def create_causal_mask(
     return causal_mask


+def create_sliding_window_causal_mask(
+    config: PretrainedConfig,
+    input_embeds: ms.Tensor,
+    attention_mask: Optional[ms.Tensor],
+    cache_position: ms.Tensor,
+    past_key_values: Optional[Cache],
+    or_mask_function: Optional[Callable] = None,
+    and_mask_function: Optional[Callable] = None,
+) -> Optional[Union[ms.Tensor, BlockMask]]:
+    """
+    Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
+    of attention pattern was mostly democratized by Mistral. If `past_key_values` has a HybridCache structure, this
+    function will return the mask corresponding to one of the "sliding_attention" layers (to align to what is needed in the
+    `modeling_xxx.py` files).
+
+    Args:
+        config (`PretrainedConfig`):
+            The model config.
+        input_embeds (`ms.Tensor`):
+            The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
+            batch size, query length and dtype.
+        attention_mask (`ms.Tensor`, optional):
+            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
+            It can also be an already prepared 4D mask, in which case it is returned as-is.
+        cache_position (`ms.Tensor`):
+            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+        past_key_values (`Cache`, optional):
+            The past key values, if we use a cache.
+        or_mask_function (`Callable`, optional):
+            An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
+            useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
+        and_mask_function (`Callable`, optional):
+            An optional mask function to combine with the sliding causal mask function (by doing the intersection of both). This is
+            useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling.
+    """
+    # If we have a HybridCache structure, here we want to create the mask for the sliding layers
+    if hasattr(past_key_values, "is_sliding") and True in past_key_values.is_sliding:
+        layer_idx = past_key_values.is_sliding.index(True)
+    else:
+        layer_idx = 0
+
+    early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments(
+        config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx
+    )
+    if early_exit:
+        return attention_mask
+
+    sliding_window = getattr(config, "sliding_window", None)
+    if sliding_window is None:
+        raise ValueError("Could not find a `sliding_window` argument in the config, or it is not set")
+
+    batch_size, dtype = input_embeds.shape[0], input_embeds.dtype
+    mask_factory_function = sliding_window_causal_mask_function(sliding_window)
+    mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
+
+    # Do not allow skip if we are compiling (this is to match BC)
+    # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
+    allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
+
+    # Allow slight deviations from sliding causal mask
+    if or_mask_function is not None or and_mask_function is not None:
+        raise NotImplementedError("`or_mask_function` or `and_mask_function` arguments are not supported yet.")
+
+    # We now create the mask
+    causal_mask = mask_interface(
+        batch_size=batch_size,
+        cache_position=cache_position,
+        kv_length=kv_length,
+        kv_offset=kv_offset,
+        mask_function=mask_factory_function,
+        attention_mask=attention_mask,
+        allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
+        local_size=sliding_window,  # Additional kwarg for sdpa
+        dtype=dtype,  # Additional kwarg for eager
+        config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
+    )
+    return causal_mask
+
+
 LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {
     "full_attention": create_causal_mask,
+    "sliding_attention": create_sliding_window_causal_mask,
 }
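For orientation (not part of the diff), the mapping above is the dispatch point between the two mask builders. The per-layer `layer_types` list below is an assumed config attribute, used only to illustrate the lookup; it is not taken from this commit.

# Hypothetical per-layer dispatch; `layer_types` is an assumed attribute.
from mindone.transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING

layer_types = ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]
for idx, layer_type in enumerate(layer_types):
    mask_builder = LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING[layer_type]
    print(idx, mask_builder.__name__)
# 0 create_causal_mask
# 1 create_sliding_window_causal_mask
# ...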

mindone/transformers/modeling_rope_utils.py

Lines changed: 58 additions & 3 deletions
@@ -16,7 +16,8 @@
 # limitations under the License.

 import math
-from typing import Optional, Tuple
+from functools import wraps
+from typing import Optional

 from transformers import PretrainedConfig
 from transformers.utils import logging
@@ -27,9 +28,63 @@
 logger = logging.get_logger(__name__)


+def dynamic_rope_update(rope_forward):
+    """
+    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
+    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+
+    Args:
+        rope_forward (Callable):
+            The forward pass of the RoPE implementation.
+
+    Returns:
+        The decorated forward pass.
+    """
+
+    def longrope_frequency_update(self, position_ids):
+        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        seq_len = mint.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        if seq_len > original_max_position_embeddings:
+            if not hasattr(self, "long_inv_freq"):
+                self.long_inv_freq, _ = self.rope_init_fn(self.config, seq_len=original_max_position_embeddings + 1)
+            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+        else:
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+    def dynamic_frequency_update(self, position_ids):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = mint.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len < self.max_seq_len_cached:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
+
 def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None, seq_len: Optional[int] = None, **rope_kwargs
-) -> Tuple[Tensor, float]:
+) -> tuple[Tensor, float]:
     """
     Computes the inverse frequencies according to the original RoPE implementation
     Args:
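A standalone, simplified sketch of the decorator pattern used by `dynamic_rope_update` above: wrap a RoPE forward so cached inverse frequencies are refreshed when the sequence grows past the cached length. The toy class, numpy math, and attribute names below are illustrative stand-ins, not the mindone code path (which dispatches on `rope_type` and registers MindSpore buffers).

# Illustrative decorator sketch only; numpy stands in for MindSpore tensors.
from functools import wraps

import numpy as np

def toy_dynamic_rope_update(rope_forward):
    @wraps(rope_forward)
    def wrapper(self, x, position_ids):
        seq_len = int(np.max(position_ids)) + 1
        if seq_len > self.max_seq_len_cached:       # growth: a dynamic RoPE would rescale here
            self.inv_freq = self.rope_init_fn(seq_len)
            self.max_seq_len_cached = seq_len
        return rope_forward(self, x, position_ids)
    return wrapper

class ToyRotaryEmbedding:
    def __init__(self, dim=8, base=10000.0, max_len=16):
        # Toy init: ignores seq_len; a real dynamic variant recomputes frequencies from it.
        self.rope_init_fn = lambda seq_len=None: 1.0 / (base ** (np.arange(0, dim, 2) / dim))
        self.inv_freq = self.rope_init_fn()
        self.max_seq_len_cached = max_len

    @toy_dynamic_rope_update
    def forward(self, x, position_ids):
        freqs = np.outer(position_ids, self.inv_freq)
        return np.cos(freqs), np.sin(freqs)

rope = ToyRotaryEmbedding()
cos, sin = rope.forward(None, np.arange(32))  # 32 > 16 triggers the recompute branch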
@@ -54,7 +109,7 @@ def _compute_default_rope_parameters(
     elif config is not None:
         base = config.rope_theta
         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
         dim = int(head_dim * partial_rotary_factor)

         attention_factor = 1.0  # Unused in this type of RoPE
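Why the `or` fallback in the hunk above matters: with a plain `getattr` default, a config that carries `head_dim = None` (as some configs may) would propagate None into the dimension computation. A small check with a hypothetical config stub:

# Hypothetical config stub; only the attributes used below are modelled.
class Cfg:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute present but explicitly unset

cfg = Cfg()
old = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
new = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
print(old)  # None -> would break `int(head_dim * partial_rotary_factor)`
print(new)  # 128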

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -94,4 +94,4 @@
     from . import glm4

 if version.parse(transformers.__version__) >= version.parse("4.53.0"):
-    from . import glm4v, vjepa2
+    from . import glm4v, minimax, vjepa2

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 2 deletions
@@ -271,8 +271,8 @@
     MODEL_NAMES_MAPPING.update({"glm4": "glm4"})

 if version.parse(transformers.__version__) >= version.parse("4.53.0"):
-    CONFIG_MAPPING_NAMES.update({"vjepa2": "VJEPA2Model"})
-    MODEL_NAMES_MAPPING.update({"vjepa2": "VJEPA2Model"})
+    CONFIG_MAPPING_NAMES.update({"minimax": "MiniMaxConfig", "vjepa2": "VJEPA2Model"})
+    MODEL_NAMES_MAPPING.update({"minimax": "MiniMax", "vjepa2": "VJEPA2Model"})


 def model_type_to_module_name(key):

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 1 deletion
@@ -603,8 +603,12 @@
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES.update({"glm4": "Glm4ForTokenClassification"})

 if version.parse(transformers.__version__) >= version.parse("4.53.0"):
-    MODEL_MAPPING_NAMES.update({"vjepa2": "VJEPA2Model"})
+    MODEL_MAPPING_NAMES.update({"minimax": "MiniMaxModel", "vjepa2": "VJEPA2Model"})
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update({"minimax": "MiniMaxForCausalLM"})
     MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES.update({"vjepa2": "VJEPA2ForVideoClassification"})
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.update({"minimax": "MiniMaxForSequenceClassification"})
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.update({"minimax": "MiniMaxForQuestionAnswering"})
+    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES.update({"minimax": "MiniMaxForTokenClassification"})

 MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
 MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
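With the mappings above registered (transformers >= 4.53.0), the auto classes can resolve the "minimax" model type. A hedged usage sketch, assuming the usual mindone auto-class exports; the checkpoint path is a placeholder, not a real model id:

# Placeholder checkpoint path; substitute a real MiniMax checkpoint directory.
import mindspore as ms
from transformers import AutoTokenizer
from mindone.transformers import AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("path/to/minimax-checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/minimax-checkpoint")

input_ids = ms.Tensor(tokenizer("Hello, MiniMax!", return_tensors="np").input_ids)
output_ids = model.generate(input_ids, max_new_tokens=16)
print(tokenizer.decode(output_ids[0].asnumpy()))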
mindone/transformers/models/minimax/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# coding=utf-8
+# Copyright 2025 MiniMaxAI and HuggingFace Inc. teams. All rights reserved.
+#
+# This code is adapted from https://github.com/huggingface/transformers
+# with modifications to run transformers on mindspore.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .modeling_minimax import *
