diff --git a/.gitignore b/.gitignore index 4f834433..f468ffd0 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ venv.bak/ # Project specifics /.idea/ /.vscode/ +/.devcontainer/ # Devenv .devenv* diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index f3e93ede..b8a611c1 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -37,6 +37,7 @@ class BlockKwargs: sequence_lengths = "sequence_lengths" # TODO: Belongs elsewhere? grad_output = "grad_output" + iteration = "iteration" @config_class(registry=True) diff --git a/fast_llm/layers/decoder/config.py b/fast_llm/layers/decoder/config.py index 403b204c..deb1b14d 100644 --- a/fast_llm/layers/decoder/config.py +++ b/fast_llm/layers/decoder/config.py @@ -1,16 +1,25 @@ +import enum import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.parameter import combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.utils import Assert +from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: from fast_llm.layers.decoder.block import BlockWithBias, DecoderBlock + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + +class StochasticMixerKwargs(BlockKwargs): + """Kwargs keys for stochastic mixer.""" + + mixer_name = "stochastic_mixer_name" + generator = "stochastic_mixer_generator" @config_class() @@ -55,6 +64,13 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return super()._from_dict(default, strict=strict) +class StochasticMixerSamplingStrategy(str, enum.Enum): + """Strategy for sampling mixers in a stochastic mixer.""" + + uniform = "uniform" + weighted = "weighted" + + @config_class(registry=True) class MixerConfig(BlockWithBiasConfig): """ @@ -71,6 +87,75 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return super()._from_dict(default, strict=strict) +@config_class(dynamic_type={MixerConfig: "stochastic"}) +class StochasticMixerConfig(MixerConfig): + """ + Stochastic mixer that uniformly samples from multiple mixer options during training. + + For supernet training, each forward pass randomly selects one mixer to execute, + training all mixers with different subsets of data. + """ + + _abstract = False + + mixers: dict[str, MixerConfig] = Field( + desc="Dict of mixer options to sample from (must contain at least 1). " + "Keys are mixer names used for debugging and namespacing.", + hint=FieldHint.architecture, + ) + + sampling_strategy: StochasticMixerSamplingStrategy = Field( + default=StochasticMixerSamplingStrategy.uniform, + desc="Strategy for sampling mixers during training.", + hint=FieldHint.feature, + ) + + sampling_weights: dict[str, float] | None = Field( + default=None, + desc="Sampling probability for each mixer by name (will be normalized to sum to 1.0). " + "Only used when sampling_strategy='weighted'. " + "If None with uniform strategy, all mixers have equal probability.", + hint=FieldHint.feature, + ) + + main_mixer_name: str | None = Field( + default=None, + desc="Name of the main mixer. 
" + "Used for inference/eval, checkpoint loading (receives pretrained weights), " + "and checkpoint saving (only this mixer is exported). " + "If None, uses the first mixer in the dict.", + hint=FieldHint.feature, + ) + + def _validate(self) -> None: + super()._validate() + + # Validate mixers dict is not empty + Assert.gt(len(self.mixers), 0) + + # Set main_mixer_name to first mixer if not specified + if self.main_mixer_name is None: + with self._set_implicit_default(): + self.main_mixer_name = next(iter(self.mixers.keys())) + + # Validate main mixer name exists + if self.main_mixer_name not in self.mixers: + raise ValueError(f"main_mixer_name '{self.main_mixer_name}' not found in mixers") + + # Validate and normalize sampling weights + if self.sampling_weights is not None: + Assert.eq(set(self.sampling_weights.keys()), set(self.mixers.keys())) + # Normalize weights to sum to 1.0 (also validates non-negative and positive sum) + normalized_values = normalize_probabilities(list(self.sampling_weights.values())) + self.sampling_weights = dict(zip(self.sampling_weights.keys(), normalized_values)) + + @property + def layer_class(self) -> "type[StochasticMixer]": + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + return StochasticMixer + + @config_class(dynamic_type={BlockConfig: "decoder"}) class DecoderBlockConfig(BlockConfig): _abstract = False diff --git a/fast_llm/layers/decoder/stochastic_mixer.py b/fast_llm/layers/decoder/stochastic_mixer.py new file mode 100644 index 00000000..329a5b87 --- /dev/null +++ b/fast_llm/layers/decoder/stochastic_mixer.py @@ -0,0 +1,167 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.decoder.block import BlockWithBias +from fast_llm.layers.decoder.config import StochasticMixerConfig, StochasticMixerKwargs, StochasticMixerSamplingStrategy +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class StochasticMixer[ConfigType: StochasticMixerConfig](BlockWithBias[ConfigType]): + """ + A mixer that stochastically samples from multiple mixer options during training. + + In training mode, each forward pass randomly selects one mixer according to + the sampling strategy. In eval mode, uses the configured inference mixer. + + This is useful for supernet training where you want to train multiple + architecture variants (e.g., attention vs. Mamba) with different data subsets. 
+ """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + return_bias: bool = True, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + return_bias=return_bias, + ) + + # Initialize all mixers + self.mixers = torch.nn.ModuleDict( + { + name: mixer_config.get_layer( + distributed_config, + hidden_dim, + lr_scale, + peft=peft, + return_bias=return_bias, + ) + for name, mixer_config in self._config.mixers.items() + } + ) + + if self._config.sampling_strategy == StochasticMixerSamplingStrategy.uniform: + self._sampling_probs = torch.ones(len(self.mixers), device="cpu") / len(self.mixers) + elif self._config.sampling_strategy == StochasticMixerSamplingStrategy.weighted: + if self._config.sampling_weights is None: + raise ValueError("sampling_weights must be provided when using weighted sampling strategy") + self._sampling_probs = torch.tensor( + [self._config.sampling_weights[name] for name in self.mixers.keys()], + dtype=torch.float32, + device="cpu", + ) + else: + raise NotImplementedError(f"Sampling strategy {self._config.sampling_strategy} not implemented") + + logger.info( + f"Initialized StochasticMixer with {len(self.mixers)} mixers: " + f"{', '.join(f'{name}={type(mixer).__name__}' for name, mixer in self.mixers.items())} " + f"(main={self._config.main_mixer_name})" + ) + + # Mark all mixer parameters with allow_no_grad since only one mixer + # is active per forward pass during training. Even though all mixers + # will eventually be trained, on any single forward pass, the non-selected + # mixers won't receive gradients. + for mixer in self.mixers.values(): + for param in mixer.parameters(recurse=True): + if hasattr(param, "allow_no_grad"): + param.allow_no_grad = True + + def setup(self, distributed: Distributed) -> None: + """Setup all mixers with the distributed context.""" + super().setup(distributed) + for mixer in self.mixers.values(): + mixer.setup(distributed) + + def _sample_mixer_name(self, kwargs: dict[str, typing.Any]) -> str: + if not self.training: + return self._config.main_mixer_name + + generator = kwargs[StochasticMixerKwargs.generator] + mixer_idx = torch.multinomial(self._sampling_probs, num_samples=1, generator=generator).item() + return list(self.mixers.keys())[mixer_idx] + + def _forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + mixer_name = self._sample_mixer_name(kwargs) + + if self._debug.enabled: + logger.debug(f"StochasticMixer selecting mixer {mixer_name}: {type(self.mixers[mixer_name]).__name__}") + + return self.mixers[mixer_name]._forward(input_, kwargs, losses, metrics) + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + from fast_llm.layers.block.config import BlockKwargs + + iteration = kwargs[BlockKwargs.iteration] + generator = torch.Generator(device="cpu") + generator.manual_seed(iteration) + kwargs[StochasticMixerKwargs.generator] = generator + + for mixer in self.mixers.values(): + mixer.preprocess(batch, kwargs) + + def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: + """ + Return expected compute usage (weighted average of all mixers). 
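+        For example (illustrative numbers): with two mixers sampled with
+        probabilities 0.75 and 0.25 and per-mixer costs of 100 and 40 units,
+        the returned value is int(0.75 * 100 + 0.25 * 40) = 85.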
+ + This gives a more accurate estimate than just using one mixer, + since during training we'll be using all of them according to + their sampling probabilities. + """ + usages = [mixer.get_compute_usage(input_, kwargs, config) for mixer in self.mixers.values()] + + # Weight by sampling probability and return the expected value + expected_usage = sum(usage * prob.item() for usage, prob in zip(usages, self._sampling_probs)) + + return int(expected_usage) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + """ + Merge loss definitions from all mixers with namespacing. + + Each mixer's losses are namespaced with the mixer name to avoid conflicts. + This ensures we allocate space for any auxiliary losses that any + of the mixers might need, even if multiple mixers have losses with the same name. + """ + all_losses = [] + for mixer_name, mixer in self.mixers.items(): + mixer_losses = mixer.get_loss_definitions(count=count) + # Namespace each loss with the mixer name to avoid conflicts + for loss_def in mixer_losses: + namespaced_loss = LossDef( + name=f"{mixer_name}/{loss_def.name}", + formatted_name=f"{mixer_name}/{loss_def.formatted_name}", + count=loss_def.count, + dtype=loss_def.dtype, + ) + all_losses.append(namespaced_loss) + + return all_losses diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py index 32293266..41431462 100644 --- a/fast_llm/models/auto.py +++ b/fast_llm/models/auto.py @@ -2,6 +2,7 @@ Import these submodules to ensure classes are added to the dynamic class registry. """ +from fast_llm.layers.attention.config import AttentionConfig # isort: skip from fast_llm.layers.ssm.config import MambaConfig, Mamba2Config, DiscreteMamba2Config # isort: skip from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig # isort: skip from fast_llm.engine.evaluation.evaluators import EvaluatorsConfig # isort: skip diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a046..c7e3f5b5 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -13,6 +13,7 @@ from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( + Apriel2CheckpointFormat, AprielHybridSSMCheckpointFormat, AutoGPTHuggingfaceCheckpointFormat, DiffusionDreamCheckpointFormat, @@ -117,6 +118,7 @@ class GPTModelConfig(FastLLMModelConfig): DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, AprielHybridSSMCheckpointFormat, + Apriel2CheckpointFormat, ) @classmethod diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py new file mode 100644 index 00000000..55b5e309 --- /dev/null +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -0,0 +1,601 @@ +""" +Apriel2 checkpoint format converter. + +Apriel2 is a HuggingFace format that closely mirrors Fast-LLM's config structure, +making conversion straightforward. 
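+
+For example, a fixed decoder exports to a nested dict of the form (abbreviated):
+
+    {"decoder": {"type": "fixed", "num_blocks": 32,
+                 "block": {"mixer": {"type": "attention", ...},
+                           "mlp": {"type": "mlp", ...},
+                           "normalization": {"type": "rms_norm"}}}}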
+""" + +import typing + +from transformers import PretrainedConfig + +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.external import WeightConverter +from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.decoder.config import DecoderBlockConfig, StochasticMixerConfig +from fast_llm.layers.ssm.config import Mamba2Config +from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.models.gpt.conversion.config import Apriel2CheckpointFormat +from fast_llm.models.gpt.conversion.llama import get_parameter_converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.mistral import ( + MistralBaseModelConverter, + MistralBlockConverter, + MistralDecoderConverter, + MistralHeadConverter, + MistralHuggingfaceCheckpointHandler, +) +from fast_llm.utils import Assert, safe_merge_dicts + + +class Apriel2AttentionConverter: + """Converter for attention mixers.""" + + @classmethod + def import_config(cls, config: dict) -> dict: + """Import attention config from Apriel2 format.""" + return { + "type": "attention", + "heads": config.get("heads", 32), + "head_groups": config.get("head_groups", config.get("heads", 32)), + "head_size": config.get("head_size", None), + "rotary": config.get("rotary", {"type": "default", "theta": 10000.0}), + "add_linear_biases": config.get("add_linear_biases", False), + "window_size": config.get("window_size", None), + } + + @classmethod + def export_config(cls, config: AttentionConfig) -> dict: + """Export attention config to Apriel2 format.""" + from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig + + # Determine rotary type string + if type(config.rotary) is DefaultRotaryConfig: + rotary_type = "default" + elif type(config.rotary) is Llama3RotaryConfig: + rotary_type = "llama3" + elif type(config.rotary) is YarnRotaryConfig: + rotary_type = "yarn" + else: + raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + + return { + "type": "attention", + "heads": config.heads, + "head_groups": config.head_groups, + "head_size": config.head_size, + "add_linear_biases": config.add_linear_biases, + "rotary": { + "type": rotary_type, + "theta": config.rotary.theta, + }, + "window_size": config.window_size, + } + + @classmethod + def get_converters( + cls, + config: AttentionConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for attention.""" + from fast_llm.models.gpt.conversion.llama import QueryWeightConverter, KeyValueWeightConverter + + # Use same weight names as Llama converter + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.query", + f"{hf_prefix}.q_proj", + config.add_linear_biases, + QueryWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.key_value", + (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), + config.add_linear_biases, + KeyValueWeightConverter, + config, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dense", + f"{hf_prefix}.o_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +class Apriel2MambaConverter: + """Converter for Mamba mixers.""" + + @classmethod + def import_config(cls, config: dict) -> dict: + """Import Mamba config from Apriel2 format.""" + return { + "type": "mamba_2", + "state_size": config.get("state_size", 
16), + "d_inner": config.get("d_inner"), + "d_xb": config.get("d_xb", None), + "dt_rank": config.get("dt_rank", "auto"), + "add_linear_biases": config.get("add_linear_biases", False), + } + + @classmethod + def export_config(cls, config: Mamba2Config) -> dict: + """Export Mamba config to Apriel2 format.""" + exported = { + "type": "mamba", + "state_size": config.state_size, + "d_inner": config.d_inner, + "d_conv": config.convolution_layer.kernel_size, + "add_linear_biases": config.add_linear_biases, + "conv_bias": config.convolution_layer.bias.enabled, + "dt_proj_bias": config.dt_layer.bias.enabled, + } + + if config.d_xb is not None: + exported["d_xb"] = config.d_xb + + if config.dt_rank != "auto": + exported["dt_rank"] = config.dt_rank + + return exported + + @classmethod + def get_converters( + cls, + config: Mamba2Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for Mamba.""" + return [ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.in_proj", + f"{hf_prefix}.in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_in_proj", + f"{hf_prefix}.dt_in_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.dt_proj", + f"{hf_prefix}.dt_proj", + config.dt_layer.bias.enabled, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.convolution", + f"{hf_prefix}.conv1d", + config.convolution_layer.bias.enabled, + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.A_log", + f"{hf_prefix}.A_log", + drop_on_export=drop_on_export, + ), + get_parameter_converter( + f"{fast_llm_prefix}.D", + f"{hf_prefix}.D", + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.out_proj", + f"{hf_prefix}.out_proj", + config.add_linear_biases, + drop_on_export=drop_on_export, + ), + ] + + +# TODO: Add converters for GatedDeltaNet and KimiLinearAttention when implemented + + +class Apriel2StochasticMixerConverter: + """Converter for stochastic mixers.""" + + @classmethod + def import_config(cls, config: dict) -> dict: + """Import stochastic mixer config from Apriel2 format.""" + # Import each sub-mixer config + mixers = {} + for name, sub_mixer_config in config.get("mixers", {}).items(): + mixer_type = sub_mixer_config.get("type") + if mixer_type == "attention": + mixers[name] = Apriel2AttentionConverter.import_config(sub_mixer_config) + elif mixer_type == "mamba": + mixers[name] = Apriel2MambaConverter.import_config(sub_mixer_config) + else: + raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + + return { + "type": "stochastic", + "mixers": mixers, + "main_mixer_name": config.get("main_mixer_name"), + "sampling_strategy": config.get("sampling_strategy", "uniform"), + } + + @classmethod + def export_config(cls, config: StochasticMixerConfig) -> dict: + """Export stochastic mixer config to Apriel2 format.""" + # Export each sub-mixer config + mixers = {} + for name, sub_mixer in config.mixers.items(): + mixer_type = type(sub_mixer) + if mixer_type is AttentionConfig: + mixers[name] = Apriel2AttentionConverter.export_config(sub_mixer) + elif mixer_type is Mamba2Config: + mixers[name] = Apriel2MambaConverter.export_config(sub_mixer) + else: + raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + + return { + "type": "stochastic", + "mixers": mixers, 
+ "main_mixer_name": config.main_mixer_name, + "sampling_strategy": config.sampling_strategy.value, + } + + @classmethod + def get_converters( + cls, + config: StochasticMixerConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for stochastic mixer.""" + converters = [] + + # Create converters for each sub-mixer + for name, sub_mixer in config.mixers.items(): + mixer_type = type(sub_mixer) + + if mixer_type is AttentionConfig: + converter_class = Apriel2AttentionConverter + # Attention sub-mixers have .self_attn nested inside + hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}.self_attn" + elif mixer_type is Mamba2Config: + converter_class = Apriel2MambaConverter + hf_sub_mixer_prefix = f"{hf_prefix}.mixers.{name}" + else: + raise ValueError(f"Unknown sub-mixer type: {mixer_type}") + + # Sub-mixers are stored in a ModuleDict with names as keys + converters.extend( + converter_class.get_converters( + sub_mixer, + f"{fast_llm_prefix}.mixers.{name}", + hf_sub_mixer_prefix, + drop_on_export=drop_on_export, + ) + ) + + return converters + + +class Apriel2BlockConverter(MistralBlockConverter): + """Converter for decoder blocks.""" + + @classmethod + def import_config(cls, config: dict, block_config: dict) -> dict: + """Import block config from Apriel2 format.""" + # Import mixer config + mixer_config = block_config.get("mixer", {}) + mixer_type = mixer_config.get("type", "attention") + + if mixer_type == "attention": + mixer = Apriel2AttentionConverter.import_config(mixer_config) + elif mixer_type == "mamba": + mixer = Apriel2MambaConverter.import_config(mixer_config) + elif mixer_type == "stochastic": + mixer = Apriel2StochasticMixerConverter.import_config(mixer_config) + else: + raise ValueError(f"Unknown mixer type: {mixer_type}") + + from fast_llm.functional.config import ActivationType + + mlp_config = block_config.get("mlp", {"type": "mlp"}) + mlp = { + "type": "mlp", + "intermediate_size": mlp_config.get("intermediate_size"), + "activation": ActivationType.from_hf_name(mlp_config.get("activation", "silu")), + "gated": True, + "add_linear_biases": mlp_config.get("add_linear_biases", False), + } + + normalization = block_config.get("normalization", {"type": "rms_norm"}) + + return { + "mixer": mixer, + "mlp": mlp, + "normalization": normalization, + } + + @classmethod + def export_config(cls, config: DecoderBlockConfig) -> dict: + """Export block config to Apriel2 format.""" + from fast_llm.layers.common.normalization.config import ( + RMSNormalizationConfig, + LayerNormalizationConfig, + NoNormalizationConfig, + ) + + # Export mixer config + mixer_type = type(config.mixer) + + if mixer_type is AttentionConfig: + mixer = Apriel2AttentionConverter.export_config(config.mixer) + elif mixer_type is Mamba2Config: + mixer = Apriel2MambaConverter.export_config(config.mixer) + elif mixer_type is StochasticMixerConfig: + mixer = Apriel2StochasticMixerConverter.export_config(config.mixer) + else: + raise ValueError(f"Unknown mixer type: {mixer_type}") + + # Determine normalization type string + norm_type = type(config.normalization) + if norm_type is RMSNormalizationConfig: + norm_type_str = "rms_norm" + elif norm_type is LayerNormalizationConfig: + norm_type_str = "layer_norm" + elif norm_type is NoNormalizationConfig: + norm_type_str = "none" + else: + raise ValueError(f"Unknown normalization type: {norm_type}") + + # Export MLP + from fast_llm.layers.decoder.mlp.config import MLPConfig + + if not 
isinstance(config.mlp, MLPConfig): + raise ValueError(f"Unsupported MLP type: {type(config.mlp)}") + + mlp = { + "type": "mlp", + "intermediate_size": config.mlp.intermediate_size, + "activation": config.mlp.activation.value, + } + + # Export normalization + normalization = {"type": norm_type_str} + + return { + "mixer": mixer, + "mlp": mlp, + "normalization": normalization, + } + + @classmethod + def get_converters( + cls, + config: DecoderBlockConfig, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for block.""" + converters = [] + + # Mixer converters - all at .mixer with appropriate sub-paths + mixer_type = type(config.mixer) + if mixer_type is AttentionConfig: + converter_class = Apriel2AttentionConverter + hf_mixer_prefix = f"{hf_prefix}.mixer.self_attn" + elif mixer_type is Mamba2Config: + converter_class = Apriel2MambaConverter + hf_mixer_prefix = f"{hf_prefix}.mixer" + elif mixer_type is StochasticMixerConfig: + converter_class = Apriel2StochasticMixerConverter + hf_mixer_prefix = f"{hf_prefix}.mixer" + else: + raise ValueError(f"Unknown mixer type: {mixer_type}") + + converters.extend( + converter_class.get_converters( + config.mixer, + f"{fast_llm_prefix}.mixer", + hf_mixer_prefix, + drop_on_export=drop_on_export, + ) + ) + + # MLP converters - Fast-LLM uses layer_1 and layer_2 + from fast_llm.models.gpt.conversion.llama import SplitWeightConverter, MLPLayer2Converter + + converters.extend([ + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_1", + (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), + config.mlp.add_linear_biases, + SplitWeightConverter, + drop_on_export=drop_on_export, + ), + *get_weight_and_bias_converters( + f"{fast_llm_prefix}.mlp.layer_2", + f"{hf_prefix}.mlp.down_proj", + config.mlp.add_linear_biases, + MLPLayer2Converter, + drop_on_export=drop_on_export, + ), + ]) + + # Normalization converters - Fast-LLM uses norm_1 and norm_2 + from fast_llm.models.gpt.conversion.llama import LlamaNormalizationConverter + + converters.extend([ + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_1", + f"{hf_prefix}.input_layernorm", + drop_on_export=drop_on_export, + ), + *LlamaNormalizationConverter.get_converters( + config.normalization, + f"{fast_llm_prefix}.norm_2", + f"{hf_prefix}.post_attention_layernorm", + drop_on_export=drop_on_export, + ), + ]) + + return converters + + +class Apriel2DecoderConverter(MistralDecoderConverter): + """Converter for decoder.""" + + block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter + + @classmethod + def import_config(cls, config: dict) -> dict: + """Import decoder config from Apriel2 format.""" + decoder_config = config.get("decoder", {}) + decoder_type = decoder_config.get("type", "fixed") + + if decoder_type == "fixed": + # Fixed decoder: single block config + block_config = decoder_config.get("block", {}) + imported_block = cls.block_converter_class.import_config(config, block_config) + + return { + "type": "fixed", + "num_blocks": decoder_config.get("num_blocks", config.get("num_hidden_layers", 32)), + "block": imported_block, + } + + elif decoder_type == "pattern": + # Pattern decoder: multiple named blocks + blocks = {} + for name, block_config in decoder_config.get("blocks", {}).items(): + blocks[name] = cls.block_converter_class.import_config(config, block_config) + + return { + "type": "pattern", + "blocks": blocks, + "pattern": 
decoder_config.get("pattern", []), + "num_blocks": decoder_config.get("num_blocks", config.get("num_hidden_layers", 32)), + } + + else: + raise ValueError(f"Unknown decoder type: {decoder_type}") + + @classmethod + def export_config(cls, config) -> dict: + """Export decoder config to Apriel2 format.""" + from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig + + if isinstance(config, FixedBlockSequenceConfig): + # Fixed decoder + block_config = cls.block_converter_class.export_config(config.block) + return { + "decoder": { + "type": "fixed", + "num_blocks": config.num_blocks, + "block": block_config, + } + } + + elif isinstance(config, PatternBlockSequenceConfig): + # Pattern decoder + blocks = {} + for name, block_config in config.blocks.items(): + blocks[name] = cls.block_converter_class.export_config(block_config) + + return { + "decoder": { + "type": "pattern", + "blocks": blocks, + "pattern": config.pattern, + "num_blocks": config.num_blocks, + } + } + + else: + raise ValueError(f"Unknown decoder config type: {type(config)}") + + @classmethod + def get_converters( + cls, + config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list[WeightConverter]: + """Get weight converters for decoder.""" + from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig + + converters = [] + if type(config) is FixedBlockSequenceConfig: + for block_index in range(config.num_blocks): + converters += cls.block_converter_class.get_converters( + config.block, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + elif type(config) is PatternBlockSequenceConfig: + for block_index in range(config.num_blocks): + block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] + converters += cls.block_converter_class.get_converters( + block_config, + f"{fast_llm_prefix}.{block_index}", + f"{hf_prefix}.{block_index}", + drop_on_export, + ) + else: + raise NotImplementedError(f"Unsupported config type: {type(config).__name__}") + return converters + + +class Apriel2HeadConverter(MistralHeadConverter): + block_converter_class: typing.ClassVar[type[Apriel2BlockConverter]] = Apriel2BlockConverter + + +class Apriel2BaseModelConverter(MistralBaseModelConverter): + decoder_converter_class: typing.ClassVar[type[Apriel2DecoderConverter]] = Apriel2DecoderConverter + head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter + + +class Apriel2HuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): + """HuggingFace checkpoint handler for Apriel2 format.""" + + format: typing.ClassVar[type[CheckpointFormat]] = Apriel2CheckpointFormat + architecture: typing.ClassVar[str] = "Apriel2ForCausalLM" + base_model_converter_class: typing.ClassVar[type[Apriel2BaseModelConverter]] = Apriel2BaseModelConverter + + @classmethod + def get_transformers_configuration_class(cls) -> type[PretrainedConfig]: + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config + + @classmethod + def get_model_files(cls) -> tuple[str, str, str | None]: + from fast_llm_external_models.apriel2 import ( + configuration_apriel2, + modeling_apriel2, + ) + + return configuration_apriel2.__file__, modeling_apriel2.__file__, None + + @classmethod + def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: + return safe_merge_dicts( + super()._export_config(config), + { + "auto_map": { + "AutoConfig": 
"configuration_apriel2.Apriel2Config", + "AutoModel": "modeling_apriel2.Apriel2Model", + "AutoModelForCausalLM": "modeling_apriel2.Apriel2ForCausalLM", + }, + }, + ) diff --git a/fast_llm/models/gpt/conversion/auto.py b/fast_llm/models/gpt/conversion/auto.py index 659d1f12..0dbf3774 100644 --- a/fast_llm/models/gpt/conversion/auto.py +++ b/fast_llm/models/gpt/conversion/auto.py @@ -3,7 +3,9 @@ from fast_llm.engine.checkpoint.external import AutoStateDictCheckpointHandler from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler from fast_llm.models.gpt.conversion.apriel import AprielHuggingfaceCheckpointHandler +from fast_llm.models.gpt.conversion.apriel2 import Apriel2HuggingfaceCheckpointHandler from fast_llm.models.gpt.conversion.config import ( + Apriel2CheckpointFormat, AprielHybridSSMCheckpointFormat, DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, @@ -35,4 +37,5 @@ class AutoGPTHuggingfaceCheckpointHandler( DiffusionDreamCheckpointFormat.name: DiffusionDreamHuggingfaceCheckpointHandler, DiffusionLlamaCheckpointFormat.name: DiffusionLlamaHuggingfaceCheckpointHandler, AprielHybridSSMCheckpointFormat.name: AprielHuggingfaceCheckpointHandler, + Apriel2CheckpointFormat.name: Apriel2HuggingfaceCheckpointHandler, } diff --git a/fast_llm/models/gpt/conversion/config.py b/fast_llm/models/gpt/conversion/config.py index 7c06906a..888fce3d 100644 --- a/fast_llm/models/gpt/conversion/config.py +++ b/fast_llm/models/gpt/conversion/config.py @@ -47,3 +47,7 @@ class DiffusionLlamaCheckpointFormat(GPTHuggingfaceCheckpointFormat): class AprielHybridSSMCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_hybrid_ssm" + + +class Apriel2CheckpointFormat(GPTHuggingfaceCheckpointFormat): + name: typing.ClassVar[str] = "apriel2" diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index a9249226..7c5f0778 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -494,6 +494,7 @@ def get_converters( f"{fast_llm_prefix}.output_weights", "lm_head.weight", drop_on_import=exported_config["tie_word_embeddings"], + drop_on_export=exported_config["tie_word_embeddings"], ), ] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ec..c24983cf 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,7 +10,7 @@ from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig @@ -220,6 +220,7 @@ def preprocess_batch( **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, + BlockKwargs.iteration: iteration, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels diff --git a/fast_llm_external_models/apriel2/__init__.py b/fast_llm_external_models/apriel2/__init__.py new file mode 100644 index 00000000..4eed64f7 --- /dev/null +++ b/fast_llm_external_models/apriel2/__init__.py @@ -0,0 +1,5 @@ +"""Apriel2 - HuggingFace format that mirrors 
Fast-LLM's architecture.""" + +from fast_llm_external_models.apriel2 import configuration_apriel2, modeling_apriel2 + +__all__ = ["configuration_apriel2", "modeling_apriel2"] diff --git a/fast_llm_external_models/apriel2/cache.py b/fast_llm_external_models/apriel2/cache.py new file mode 100644 index 00000000..02a348c4 --- /dev/null +++ b/fast_llm_external_models/apriel2/cache.py @@ -0,0 +1,356 @@ +from __future__ import annotations +from typing import Optional, Any +import torch +from transformers.cache_utils import Cache + + +class _AttentionCache: + __slots__ = ['key', 'value', 'window'] + + def __init__(self, window=None): + self.key = None + self.value = None + self.window = window + + def update(self, key, value): + if self.key is None: + if self.window and key.shape[-2] > self.window: + self.key = key[..., -self.window:, :].contiguous() + self.value = value[..., -self.window:, :].contiguous() + else: + self.key = key.contiguous() + self.value = value.contiguous() + else: + if self.window: + self.key = self._window(self.key, key) + self.value = self._window(self.value, value) + else: + self.key = torch.cat([self.key, key], -2) + self.value = torch.cat([self.value, value], -2) + return self.key, self.value + + def _window(self, cache, new): + if cache.shape[-2] == self.window and new.shape[-2] == 1: + cache = cache.roll(-1, -2) + cache[..., -1:, :] = new + return cache + return torch.cat([cache, new], -2)[..., -self.window:, :].contiguous() + + +class _SSMCache: + __slots__ = ['conv', 'recurrent'] + + def __init__(self): + self.conv = None + self.recurrent = None + + +class _DummyCacheLayer: + pass + + +class Apriel2Cache(Cache): + + def __init__(self, config): + super().__init__(layer_class_to_replicate=_DummyCacheLayer) + self.config = config + n = config.num_hidden_layers + self.layers = [] + self.mixer_types = [] + self.active_mixers = [None] * n + + for i in range(n): + block = config.get_block_config(i) + mixer = block.get("mixer", {}) + mtype = mixer.get("type", "attention") + + if mtype == "stochastic": + sub = {} + main = mixer.get("main_mixer_name") + for name, cfg in mixer.get("mixers", {}).items(): + if cfg.get("type") == "attention": + sub[name] = _AttentionCache(cfg.get("sliding_window")) + else: + sub[name] = _SSMCache() + self.layers.append(sub) + self.mixer_types.append(mixer["mixers"][main].get("type") if main else "attention") + elif mtype == "attention": + self.layers.append(_AttentionCache(mixer.get("sliding_window"))) + self.mixer_types.append("attention") + else: + self.layers.append(_SSMCache()) + self.mixer_types.append(mtype) + + def update(self, key_states, value_states, layer_idx, cache_kwargs=None): + layer = self.layers[layer_idx] + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer is None: + raise RuntimeError(f"Stochastic layer {layer_idx} needs active_mixer set") + return layer[mixer].update(key_states, value_states) + return layer.update(key_states, value_states) + + def set_active_mixer(self, layer_idx, mixer_name): + self.active_mixers[layer_idx] = mixer_name + + def get_seq_length(self, layer_idx=0): + layer = self.layers[layer_idx] + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer and isinstance(layer[mixer], _AttentionCache): + return layer[mixer].key.shape[-2] if layer[mixer].key is not None else 0 + return 0 + if isinstance(layer, _AttentionCache): + return layer.key.shape[-2] if layer.key is not None else 0 + return 0 + + def get_max_cache_shape(self, layer_idx=0): + layer = 
self.layers[layer_idx] + if isinstance(layer, dict): + mixer = self.active_mixers[layer_idx] + if mixer and isinstance(layer[mixer], _AttentionCache): + return layer[mixer].window + elif isinstance(layer, _AttentionCache): + return layer.window + return None + + def get_mask_sizes(self, cache_position, layer_idx): + query_length = cache_position.shape[0] + past_seen_tokens = self.get_seq_length(layer_idx) + kv_length = query_length + past_seen_tokens + kv_offset = past_seen_tokens + return kv_length, kv_offset + + @property + def has_previous_state(self): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + if isinstance(cache, _SSMCache) and cache.conv is not None: + return True + elif isinstance(layer, _SSMCache) and layer.conv is not None: + return True + return False + + @property + def key_cache(self): + return _LayerListAccessor(self, 'key') + + @property + def value_cache(self): + return _LayerListAccessor(self, 'value') + + @property + def conv_states(self): + return _LayerListAccessor(self, 'conv') + + @property + def recurrent_states(self): + return _LayerListAccessor(self, 'recurrent') + + def reorder_cache(self, beam_idx): + for i, layer in enumerate(self.layers): + if isinstance(layer, dict): + for cache in layer.values(): + self._reorder_cache_obj(cache, beam_idx) + else: + self._reorder_cache_obj(layer, beam_idx) + + def _reorder_cache_obj(self, cache, beam_idx): + if isinstance(cache, _AttentionCache): + if cache.key is not None: + cache.key = cache.key.index_select(0, beam_idx.to(cache.key.device)) + cache.value = cache.value.index_select(0, beam_idx.to(cache.value.device)) + elif isinstance(cache, _SSMCache): + if cache.conv is not None: + cache.conv = cache.conv.index_select(0, beam_idx.to(cache.conv.device)) + if cache.recurrent is not None: + cache.recurrent = cache.recurrent.index_select(0, beam_idx.to(cache.recurrent.device)) + + def reset(self): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + self._reset_cache_obj(cache) + else: + self._reset_cache_obj(layer) + + def _reset_cache_obj(self, cache): + if isinstance(cache, _AttentionCache): + cache.key = None + cache.value = None + elif isinstance(cache, _SSMCache): + cache.conv = None + cache.recurrent = None + + def crop(self, max_length): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + if isinstance(cache, _AttentionCache) and cache.key is not None: + cache.key = cache.key[..., :max_length, :] + cache.value = cache.value[..., :max_length, :] + elif isinstance(layer, _AttentionCache) and layer.key is not None: + layer.key = layer.key[..., :max_length, :] + layer.value = layer.value[..., :max_length, :] + + def batch_repeat_interleave(self, repeats): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + self._batch_repeat_cache_obj(cache, repeats) + else: + self._batch_repeat_cache_obj(layer, repeats) + + def _batch_repeat_cache_obj(self, cache, repeats): + if isinstance(cache, _AttentionCache): + if cache.key is not None: + cache.key = cache.key.repeat_interleave(repeats, dim=0) + cache.value = cache.value.repeat_interleave(repeats, dim=0) + elif isinstance(cache, _SSMCache): + if cache.conv is not None: + cache.conv = cache.conv.repeat_interleave(repeats, dim=0) + if cache.recurrent is not None: + cache.recurrent = cache.recurrent.repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices): + for layer in self.layers: + if isinstance(layer, 
dict): + for cache in layer.values(): + self._batch_select_cache_obj(cache, indices) + else: + self._batch_select_cache_obj(layer, indices) + + def _batch_select_cache_obj(self, cache, indices): + if isinstance(cache, _AttentionCache): + if cache.key is not None: + cache.key = cache.key.index_select(0, indices.to(cache.key.device)) + cache.value = cache.value.index_select(0, indices.to(cache.value.device)) + elif isinstance(cache, _SSMCache): + if cache.conv is not None: + cache.conv = cache.conv.index_select(0, indices.to(cache.conv.device)) + if cache.recurrent is not None: + cache.recurrent = cache.recurrent.index_select(0, indices.to(cache.recurrent.device)) + + @property + def is_compileable(self): + return False + + @property + def is_initialized(self): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + if isinstance(cache, _AttentionCache) and cache.key is not None: + return True + if isinstance(cache, _SSMCache) and cache.conv is not None: + return True + else: + if isinstance(layer, _AttentionCache) and layer.key is not None: + return True + if isinstance(layer, _SSMCache) and layer.conv is not None: + return True + return False + + @property + def is_sliding(self): + result = [] + for layer in self.layers: + if isinstance(layer, dict): + has_sliding = any( + isinstance(cache, _AttentionCache) and cache.window is not None + for cache in layer.values() + ) + result.append(has_sliding) + elif isinstance(layer, _AttentionCache): + result.append(layer.window is not None) + else: + result.append(False) + return result + + @property + def max_batch_size(self): + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + if isinstance(cache, _AttentionCache) and cache.key is not None: + return cache.key.shape[0] + if isinstance(cache, _SSMCache) and cache.conv is not None: + return cache.conv.shape[0] + else: + if isinstance(layer, _AttentionCache) and layer.key is not None: + return layer.key.shape[0] + if isinstance(layer, _SSMCache) and layer.conv is not None: + return layer.conv.shape[0] + return None + + @property + def max_cache_len(self): + max_len = None + for layer in self.layers: + if isinstance(layer, dict): + for cache in layer.values(): + if isinstance(cache, _AttentionCache): + if cache.window is not None: + max_len = cache.window if max_len is None else min(max_len, cache.window) + elif isinstance(layer, _AttentionCache): + if layer.window is not None: + max_len = layer.window if max_len is None else min(max_len, layer.window) + return max_len + + def __len__(self): + return len(self.layers) + + def __getitem__(self, idx): + layer = self.layers[idx] + if isinstance(layer, dict): + mixer = self.active_mixers[idx] + if mixer and isinstance(layer[mixer], _AttentionCache): + c = layer[mixer] + if c.key is not None: + return c.key, c.value + elif isinstance(layer, _AttentionCache): + if layer.key is not None: + return layer.key, layer.value + + for i, l in enumerate(self.layers): + if isinstance(l, _AttentionCache) and l.key is not None: + return torch.empty((0,), device=l.key.device, dtype=l.key.dtype), torch.empty((0,), device=l.key.device, dtype=l.key.dtype) + elif isinstance(l, dict): + for c in l.values(): + if isinstance(c, _AttentionCache) and c.key is not None: + return torch.empty((0,), device=c.key.device, dtype=c.key.dtype), torch.empty((0,), device=c.key.device, dtype=c.key.dtype) + return torch.empty((0,)), torch.empty((0,)) + + +class _LayerListAccessor: + __slots__ = ['cache', 'attr'] + + def 
__init__(self, cache, attr): + self.cache = cache + self.attr = attr + + def __getitem__(self, idx): + layer = self.cache.layers[idx] + if isinstance(layer, dict): + mixer = self.cache.active_mixers[idx] + if mixer is None: + raise RuntimeError( + f"Stochastic layer {idx} requires set_active_mixer() to be called before accessing cache. " + f"Available mixers: {list(layer.keys())}" + ) + return getattr(layer[mixer], self.attr) + return getattr(layer, self.attr, None) + + def __setitem__(self, idx, value): + layer = self.cache.layers[idx] + if isinstance(layer, dict): + mixer = self.cache.active_mixers[idx] + if mixer is None: + raise RuntimeError( + f"Stochastic layer {idx} requires set_active_mixer() to be called before accessing cache. " + f"Available mixers: {list(layer.keys())}" + ) + setattr(layer[mixer], self.attr, value) + elif hasattr(layer, self.attr): + setattr(layer, self.attr, value) diff --git a/fast_llm_external_models/apriel2/configuration_apriel2.py b/fast_llm_external_models/apriel2/configuration_apriel2.py new file mode 100644 index 00000000..73f92714 --- /dev/null +++ b/fast_llm_external_models/apriel2/configuration_apriel2.py @@ -0,0 +1,138 @@ +""" +Apriel2 configuration - HuggingFace format that mirrors Fast-LLM's config structure. +""" + +from typing import Any, Optional, Union + +from transformers import PretrainedConfig + + +class Apriel2Config(PretrainedConfig): + """ + Configuration class for Apriel2 models. + + This config mirrors Fast-LLM's hierarchical structure: + + decoder: + type: "fixed" or "pattern" + num_blocks: int + + # For fixed decoder: + block: + mixer: {type, ...params} + mlp: {type, ...params} + normalization: {type} + + # For pattern decoder: + blocks: + block_name: + mixer: {type, ...params} + mlp: {type, ...params} + normalization: {type} + pattern: [block_name, ...] 
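+
+    # Stochastic mixer (usable as the mixer of any block; names are illustrative):
+    mixer:
+        type: stochastic
+        main_mixer_name: attn
+        mixers:
+            attn: {type: attention, ...params}
+            ssm: {type: mamba, ...params}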
+ + Mixer types: attention, mamba, gated_delta_net, kimi_linear_attention, stochastic + For stochastic mixers, mixer.mixers is a dict of {name: mixer_config} + """ + + model_type = "apriel2" + + def __init__( + self, + vocab_size: int = 32000, + hidden_size: int = 4096, + # Decoder configuration + decoder: Optional[dict] = None, + # Embedding config + max_position_embeddings: int = 2048, + rope_theta: float = 10000.0, + # Attention defaults (can be overridden per-block) + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = None, + head_dim: Optional[int] = None, + # Head config + rms_norm_eps: float = 1e-5, + tie_word_embeddings: bool = False, + # Generation config + bos_token_id: int = 1, + eos_token_id: int = 2, + pad_token_id: Optional[int] = None, + use_cache: bool = True, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + self.rms_norm_eps = rms_norm_eps + self.tie_word_embeddings = tie_word_embeddings + self.use_cache = use_cache + + # Decoder configuration with defaults + self.decoder = decoder or { + "type": "fixed", + "num_blocks": 32, + "block": { + "mixer": {"type": "attention"}, + "mlp": {"type": "mlp"}, + "normalization": {"type": "rms_norm"}, + }, + } + + # Convenience accessor for HuggingFace compatibility + self.num_hidden_layers = self.decoder.get("num_blocks", 32) + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def get_text_config(self, decoder: bool = False): + """Return self to ensure tie_word_embeddings is accessible.""" + return self + + def get_block_name(self, layer_idx: int) -> str: + """Get the block name for a specific layer.""" + decoder_type = self.decoder.get("type", "fixed") + + if decoder_type == "fixed": + return "block" + elif decoder_type == "pattern": + pattern = self.decoder.get("pattern", []) + if not pattern: + raise ValueError("Pattern decoder requires 'pattern' field") + return pattern[layer_idx % len(pattern)] + else: + raise ValueError(f"Unknown decoder type: {decoder_type}") + + def get_block_config(self, layer_idx: int) -> dict: + """Get the block configuration for a specific layer.""" + decoder_type = self.decoder.get("type", "fixed") + + if decoder_type == "fixed": + # Fixed decoder: all blocks use the same configuration + return self.decoder.get("block", self._default_block_config()) + elif decoder_type == "pattern": + # Pattern decoder: blocks follow a repeating pattern + blocks = self.decoder.get("blocks", {}) + pattern = self.decoder.get("pattern", []) + if not blocks or not pattern: + raise ValueError("Pattern decoder requires 'blocks' and 'pattern' fields") + block_name = pattern[layer_idx % len(pattern)] + return blocks[block_name] + else: + raise ValueError(f"Unknown decoder type: {decoder_type}") + + def _default_block_config(self) -> dict: + """Create default block configuration.""" + return { + "mixer": {"type": "attention"}, + "mlp": {"type": "mlp"}, + "normalization": {"type": "rms_norm"}, + } diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py new file mode 
100644 index 00000000..fe277c7f --- /dev/null +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -0,0 +1,1151 @@ +""" +Apriel2 modeling - HuggingFace format that mirrors Fast-LLM's architecture. +""" + +import math +import random +from typing import Any, Optional, Union +from types import SimpleNamespace + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn +from transformers import GenerationMixin, PreTrainedModel +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.processing_utils import Unpack +from transformers.utils import logging + +from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config +from fast_llm_external_models.apriel2.cache import Apriel2Cache +from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralMLP, + MistralRMSNorm, +) +from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextGatedDeltaNet + +from transformers.utils.import_utils import is_torch_flex_attn_available +from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask + +# Try to import optimized kernels from mamba_ssm +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + _mamba_ssm_available = True +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + selective_scan_fn = None + selective_state_update = None + _mamba_ssm_available = False + +# Try to import FLA (Fused Linear Attention) library for optimizations +try: + from fla.modules import FusedRMSNormGated + from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + _fla_available = True +except ImportError: + FusedRMSNormGated = None + chunk_gated_delta_rule = None + fused_recurrent_gated_delta_rule = None + _fla_available = False + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask +else: + BlockMask = torch.Tensor + +logger = logging.get_logger(__name__) + +is_fast_path_available = _mamba_ssm_available and all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + +# Log availability of optimized kernels +if not _mamba_ssm_available: + logger.warning("mamba_ssm library not available. Mamba layers will not work without it.") +if not _fla_available: + logger.info("FLA (Fused Linear Attention) library not available. Using fallback implementations for GatedDeltaNet.") +if not is_fast_path_available: + logger.warning("Fast path for Mamba is not available. Some kernels are missing.") + + +# Helper functions for Mamba +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@torch.compile +def segsum(x): + """More stable segment sum calculation.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +@torch.compile +def materialize_mixer(A_log, B, C, D): + """ + Since the transfer matrix will be equated to the attention matrix, + we need to support the form: torch.matmul(attn_weights, value_states). + Thus, y = torch.matmul(T, X) + """ + batch_size, length, n_heads, d_state = B.shape + assert A_log.shape == (batch_size, length, n_heads) + assert B.shape == C.shape == (batch_size, length, n_heads, d_state) + + A_log = rearrange(-F.softplus(A_log), "b l h -> b h l") + powers = torch.exp(segsum(A_log)) + T = torch.einsum("blhn,bshn,bhls->bhsl", C, B, powers) + + if D is not None: + T[:, :, torch.arange(length), torch.arange(length)] += D.view(1, n_heads, 1) + + T = rearrange(T, "b h z l -> b h l z") + return T + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """Tunes out the hidden states for padding tokens.""" + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + return hidden_states + + +class Apriel2Attention(nn.Module): + """ + Attention wrapper that handles rotary embeddings internally. + Contains self.self_attn and self.rotary_emb as sub-modules. + Mirrors Fast-LLM's architecture where each Attention has its own rotary. 
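+
+    Example ``mixer_config`` dict (illustrative values; keys as read in ``__init__``):
+
+        {"type": "attention", "heads": 32, "head_groups": 8, "head_size": 128,
+         "rotary": {"type": "default", "theta": 10000.0}, "sliding_window": None}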
+ """ + + def __init__(self, d_model: int, mixer_config: dict, layer_idx: int, config): + super().__init__() + + # Extract attention parameters from mixer_config + num_heads = mixer_config.get("heads", 32) + num_key_value_heads = mixer_config.get("head_groups", num_heads) + head_dim = mixer_config.get("head_size", d_model // num_heads) + rope_theta = ( + mixer_config.get("rotary", {}).get("theta", 10000.0) + if isinstance(mixer_config.get("rotary"), dict) + else 10000.0 + ) + + # Create attention config + attn_config = SimpleNamespace( + hidden_size=d_model, + num_attention_heads=num_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=rope_theta, + attention_dropout=0.0, + sliding_window=mixer_config.get("sliding_window", None), + _attn_implementation=config._attn_implementation, + ) + + # Create attention sub-module + self.self_attn = MistralAttention(attn_config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple] = None, + **kwargs, + ): + return self.self_attn(hidden_states, position_embeddings, attention_mask, **kwargs) + + +def create_mixer(mixer_config: dict, hidden_size: int, layer_idx: int, config, allow_stochastic: bool = True): + mixer_type = mixer_config.get("type", "attention") + + if mixer_type == "attention": + return Apriel2Attention(hidden_size, mixer_config, layer_idx, config) + elif mixer_type == "mamba": + return Apriel2Mamba(hidden_size, mixer_config, layer_idx=layer_idx) + elif mixer_type == "gated_delta_net": + return Apriel2GatedDeltaNet(hidden_size, mixer_config, layer_idx=layer_idx) + elif mixer_type == "kimi_linear_attention": + return KimiLinearAttention(hidden_size, mixer_config, layer_idx=layer_idx) + elif mixer_type == "stochastic": + if not allow_stochastic: + raise ValueError("Stochastic mixers cannot contain nested stochastic mixers") + return Apriel2StochasticMixer(mixer_config, config, layer_idx) + else: + raise ValueError(f"Unknown mixer type: {mixer_type}") + + +class Apriel2Mamba(nn.Module): + """Mamba mixer.""" + + def __init__( + self, + d_model, + config_dict: dict, + layer_idx=None, + device=None, + dtype=None, + ): + """Initialize Mamba from a config dictionary.""" + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + # Extract parameters from config dict + d_state = config_dict.get("state_size", 16) + d_inner = config_dict.get("d_inner") + d_xb = config_dict.get("d_xb", None) + d_conv = config_dict.get("d_conv", 4) + expand = config_dict.get("expand", 2) + dt_rank = config_dict.get("dt_rank", "auto") + dt_min = config_dict.get("dt_min", 0.001) + dt_max = config_dict.get("dt_max", 0.1) + dt_init = config_dict.get("dt_init", "random") + dt_scale = config_dict.get("dt_scale", 1.0) + dt_init_floor = config_dict.get("dt_init_floor", 1e-4) + repeat_kv_before_conv = config_dict.get("repeat_kv_before_conv", True) + conv_bias = config_dict["conv_bias"] + bias = config_dict.get("add_linear_biases", False) + dt_proj_bias = config_dict["dt_proj_bias"] + + self.d_model = d_model + self.d_xb = d_xb if d_xb is not None else d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = d_inner if d_inner is not None else int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = True + 
self.layer_idx = layer_idx + self.repeat_kv_before_conv = repeat_kv_before_conv + + if self.repeat_kv_before_conv: + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + else: + self.conv1d = nn.Conv1d( + in_channels=self.d_xb, + out_channels=self.d_xb, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_xb, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = "silu" + self.act = nn.SiLU() + + self.num_xb_head = self.d_xb // self.d_state + self.num_C_head = self.d_inner // self.d_state + self.repeat_group = self.num_C_head // self.num_xb_head + + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank**-0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + if self.dt_proj.bias is not None: + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + ).clamp(min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + d=self.d_inner, + ).contiguous() + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value=None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + """Forward pass for Mamba.""" + if not is_fast_path_available: + raise RuntimeError( + "Mamba requires mamba_ssm library with causal_conv1d and selective_scan kernels. " + "Install with: pip install mamba-ssm causal-conv1d" + ) + if "cuda" not in self.in_proj.weight.device.type: + raise RuntimeError("Mamba only supports CUDA devices. 
Current device: " + str(self.in_proj.weight.device)) + + cache_position = kwargs.get("cache_position", None) + batch, seqlen, dim = hidden_states.shape + + ssm_state, conv_state = None, None + use_precomputed_states = False + + seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0 + use_precomputed_states = ( + past_key_value is not None + and isinstance(past_key_value, Apriel2Cache) + and past_key_value.conv_states[self.layer_idx] is not None + and seqlen == 1 + and past_key_value.conv_states[self.layer_idx].shape[0] + == past_key_value.recurrent_states[self.layer_idx].shape[0] + == batch + and cache_position is not None + and seqlen_offset > 0 + ) + + ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch) + # Adaptive mode selection: use step() for single-token generation + # This provides significant speedup during autoregressive decoding + if use_precomputed_states: + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return (out,) + + A = -torch.exp(self.A_log.float()) + + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [self.d_inner, self.d_xb, self.d_xb, self.d_inner], + dim=-1, + ) + + x = rearrange(x, "b l d -> b d l") + z = rearrange(z, "b l d -> b d l") + + B = rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state) + B = repeat_kv(B, self.repeat_group) + B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() + C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() + + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + dt = rearrange(dt, "b l d -> b d l") + + if self.repeat_kv_before_conv: + x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + + if conv_state is not None: + conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) + + # Compute short convolution + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + if not self.repeat_kv_before_conv: + x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) + x = repeat_kv(x, self.repeat_group) + x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float() if self.dt_proj.bias is not None else None, + delta_softplus=True, + return_last_state=(ssm_state is not None), + ) + + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) + + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + + return (out[:, :seqlen, :],) + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + + hidden_states_input = hidden_states.squeeze(1) + + A = -torch.exp(self.A_log.float()) + + zxbc = self.in_proj(hidden_states_input) + z, x, B, C = torch.split( + zxbc, + [self.d_inner, self.d_xb, self.d_xb, self.d_inner], + dim=-1, + ) + + B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) + B = torch.repeat_interleave(B, dim=1, 
repeats=self.repeat_group) + C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() + + dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) + + if self.repeat_kv_before_conv: + x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) + x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) + x = rearrange(x, "b n_group dstate -> b (n_group dstate)") + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + if not self.repeat_kv_before_conv: + x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) + x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group) + x = rearrange(x, "b n_group dstate -> b (n_group dstate)") + + x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head) + dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head) + A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head) + D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head) + z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head) + dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head) if self.dt_proj.bias is not None else None + + # SSM step + assert selective_state_update is not None + y = selective_state_update(ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True) + y = rearrange(y, "b h d -> b (h d)") + out = self.out_proj(y) + + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + if self.repeat_kv_before_conv: + conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype) + else: + conv_state = torch.zeros(batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + ssm_state = torch.zeros( + batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): + assert self.layer_idx is not None + if inference_params is None or not isinstance(inference_params, Apriel2Cache): + return None, None + + if inference_params.conv_states[self.layer_idx] is None: + conv_state, ssm_state = self.allocate_inference_cache(batch_size, max_seqlen=0) + inference_params.conv_states[self.layer_idx] = conv_state + inference_params.recurrent_states[self.layer_idx] = ssm_state + + ssm_state = inference_params.recurrent_states[self.layer_idx] + conv_state = inference_params.conv_states[self.layer_idx] + + if initialize_states: + ssm_state.zero_() + conv_state.zero_() + + return ssm_state, conv_state + + +class Apriel2GatedDeltaNet(nn.Module): + """Wrapper around Qwen3NextGatedDeltaNet to match apriel2 interface.""" + + def __init__( + self, + d_model, + config_dict: dict, + layer_idx=None, + device=None, + dtype=None, + ): + super().__init__() + + # Map config_dict to Qwen3NextConfig format + config = SimpleNamespace( + 
hidden_size=d_model, + linear_num_value_heads=config_dict.get("num_value_heads", 32), + linear_num_key_heads=config_dict.get("num_key_heads", 8), + linear_key_head_dim=config_dict.get("key_head_dim", 64), + linear_value_head_dim=config_dict.get("value_head_dim", 64), + linear_conv_kernel_dim=config_dict.get("conv_kernel_size", 4), + hidden_act=config_dict.get("activation", "silu"), + rms_norm_eps=config_dict.get("norm_eps", 1e-5), + dtype=dtype, + ) + + self.gdn = Qwen3NextGatedDeltaNet(config, layer_idx) + + def forward(self, hidden_states: torch.Tensor, past_key_value=None, attention_mask=None, **kwargs): + cache_position = kwargs.get("cache_position", None) + output = self.gdn( + hidden_states, cache_params=past_key_value, cache_position=cache_position, attention_mask=attention_mask + ) + return (output,) + + +class KimiLinearAttention(nn.Module): + """KimiLinearAttention mixer - stub for future implementation.""" + + def __init__( + self, + d_model, + config_dict: dict, + layer_idx=None, + device=None, + dtype=None, + ): + super().__init__() + raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") + + def forward(self, hidden_states: torch.Tensor, **kwargs): + raise NotImplementedError("KimiLinearAttention not yet implemented in apriel2") + + +class Apriel2DecoderBlock(nn.Module): + def __init__(self, config: Apriel2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + + # Get block name and config for this layer + self.block_name = config.get_block_name(layer_idx) + block_config = config.get_block_config(layer_idx) + + # Create mixer based on type + mixer_config = block_config.get("mixer", {"type": "attention"}) + self.mixer = create_mixer(mixer_config, config.hidden_size, layer_idx, config, allow_stochastic=True) + + # Create MLP + mlp_config = block_config.get("mlp", {"type": "mlp"}) + self.mlp = self._create_mlp(mlp_config, config) + + # Create normalization layers + norm_config = block_config.get("normalization", {"type": "rms_norm"}) + self.input_layernorm = self._create_norm(norm_config, config) + self.post_attention_layernorm = self._create_norm(norm_config, config) + + def _create_mlp(self, mlp_config: dict, config: Apriel2Config): + """Create MLP based on config.""" + mlp_type = mlp_config.get("type", "mlp") + + if mlp_type == "mlp": + intermediate_size = mlp_config.get("intermediate_size", config.hidden_size * 4) + mlp_cfg = SimpleNamespace( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=mlp_config.get("activation", "silu"), + ) + return MistralMLP(mlp_cfg) + else: + raise ValueError(f"Unknown MLP type: {mlp_type}") + + def _create_norm(self, norm_config: dict, config: Apriel2Config): + """Create normalization layer based on config.""" + norm_type = norm_config.get("type", "rms_norm") + if norm_type == "rms_norm": + return MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + elif norm_type == "layer_norm": + return nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + raise ValueError(f"Unknown normalization type: {norm_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Apriel2Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + position_embeddings=None, + **kwargs, + ) -> tuple: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + 
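+ # Pre-norm residual block: the mixer runs on the normalized hidden states and its output is
+ # added back onto the residual below. When this block is a stochastic block, self.mixer is an
+ # Apriel2StochasticMixer and the attention_mask / position_embeddings passed here are dicts
+ # keyed by sub-mixer name; the stochastic mixer picks out the entry for whichever sub-mixer it
+ # samples (training) or for its main mixer (eval). Plain attention blocks receive the
+ # (cos, sin) tuple and the causal mask directly.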
mixer_outputs = self.mixer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = mixer_outputs[0] + hidden_states = residual + hidden_states + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (mixer_outputs[1],) if len(mixer_outputs) > 1 else (None,) + if use_cache: + outputs += (mixer_outputs[2] if len(mixer_outputs) > 2 else None,) + + return outputs + + +class Apriel2StochasticMixer(nn.Module): + """ + Stochastic mixer that contains multiple mixer options. + + During training: randomly samples one mixer per forward pass + During inference: uses the main_mixer + """ + + def __init__(self, mixer_config: dict, config: Apriel2Config, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + + # Get sub-mixer configs + mixers_config = mixer_config.get("mixers", {}) + self.main_mixer_name = mixer_config.get("main_mixer_name", list(mixers_config.keys())[0]) + + # Sampling strategy + self.sampling_strategy = mixer_config.get("sampling_strategy", "uniform") + sampling_weights = mixer_config.get("sampling_weights", None) + + # Create each sub-mixer + self.mixers = nn.ModuleDict() + for name, sub_mixer_config in mixers_config.items(): + self.mixers[name] = create_mixer( + sub_mixer_config, config.hidden_size, layer_idx, config, allow_stochastic=False + ) + + # Set up sampling probabilities + mixer_names = list(self.mixers.keys()) + if self.sampling_strategy == "uniform": + self._sampling_probs = [1.0 / len(self.mixers)] * len(self.mixers) + elif self.sampling_strategy == "weighted": + if sampling_weights is None: + raise ValueError("sampling_weights must be provided when using weighted sampling strategy") + # Normalize weights to sum to 1.0 + total = sum(sampling_weights.get(name, 1.0) for name in mixer_names) + self._sampling_probs = [sampling_weights.get(name, 1.0) / total for name in mixer_names] + else: + raise ValueError(f"Unknown sampling_strategy: {self.sampling_strategy}") + + self._mixer_names = mixer_names + logger.info( + f"Initialized Apriel2StochasticMixer at layer {layer_idx} with {len(self.mixers)} mixers: " + f"{', '.join(mixer_names)} (main={self.main_mixer_name}, strategy={self.sampling_strategy})" + ) + + def forward( + self, hidden_states: torch.Tensor, attention_mask=None, position_embeddings: Optional[dict] = None, **kwargs + ): + # Sample mixer during training, use main_mixer during inference + if self.training: + mixer_name = random.choices(self._mixer_names, weights=self._sampling_probs)[0] + else: + mixer_name = self.main_mixer_name + + # Set active mixer in cache for proper state routing + past_key_value = kwargs.get("past_key_value") + if past_key_value is not None and hasattr(past_key_value, "set_active_mixer"): + past_key_value.set_active_mixer(self.layer_idx, mixer_name) + + mixer = self.mixers[mixer_name] + mixer_position_embeddings = position_embeddings.get(mixer_name) if position_embeddings else None + mixer_attention_mask = ( + attention_mask.get(mixer_name) if isinstance(attention_mask, dict) else attention_mask + ) + return mixer( + hidden_states, attention_mask=mixer_attention_mask, position_embeddings=mixer_position_embeddings, **kwargs + ) + + +class 
Apriel2PreTrainedModel(PreTrainedModel): + config_class = Apriel2Config + base_model_prefix = "model" + _no_split_modules = ["Apriel2DecoderBlock"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = False + _supports_static_cache = False + _supports_attention_backend = True + + def _prepare_cache_for_generation( + self, generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, *args + ): + if generation_config.use_cache is False: + return + model_kwargs["past_key_values"] = Apriel2Cache(config=self.config) + + def _init_weights(self, module): + std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, MistralRMSNorm): + module.weight.data.fill_(1.0) + + +class Apriel2Model(Apriel2PreTrainedModel): + def __init__(self, config: Apriel2Config): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Embeddings + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + # Build shared rotary embeddings (one per unique block type) + self.rotary_embs = nn.ModuleDict() + self._build_rotary_embs() + + # Decoder blocks + self.layers = nn.ModuleList( + [Apriel2DecoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + # Final norm + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + self.post_init() + + def _create_rotary_emb_for_attention(self, mixer_config: dict): + from transformers.models.mistral.modeling_mistral import MistralRotaryEmbedding + + head_dim = mixer_config.get("head_size", self.config.hidden_size // mixer_config.get("heads", 32)) + rope_theta = ( + mixer_config.get("rotary", {}).get("theta", 10000.0) + if isinstance(mixer_config.get("rotary"), dict) + else 10000.0 + ) + + rotary_config = SimpleNamespace( + max_position_embeddings=self.config.max_position_embeddings, + rope_theta=rope_theta, + head_dim=head_dim, + hidden_size=self.config.hidden_size, + num_attention_heads=mixer_config.get("heads", 32), + partial_rotary_factor=1.0, + ) + return MistralRotaryEmbedding(config=rotary_config) + + def _build_attn_config_for_mask(self, mixer_config: dict): + """Build attention config for causal mask creation.""" + num_heads = mixer_config.get("heads", 32) + num_key_value_heads = mixer_config.get("head_groups", num_heads) + head_dim = mixer_config.get("head_size", self.config.hidden_size // num_heads) + + return SimpleNamespace( + hidden_size=self.config.hidden_size, + num_attention_heads=num_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + max_position_embeddings=self.config.max_position_embeddings, + sliding_window=mixer_config.get("sliding_window", None), + _attn_implementation=self.config._attn_implementation, + ) + + def _build_rotary_embs(self): + """Build rotary embedding instances for all unique attention blocks.""" + decoder_type = self.config.decoder.get("type", "fixed") + + if decoder_type == "fixed": + block_config = 
self.config.decoder.get("block", {}) + self._build_rotary_embs_for_block("block", block_config) + elif decoder_type == "pattern": + blocks = self.config.decoder.get("blocks", {}) + for block_name, block_config in blocks.items(): + self._build_rotary_embs_for_block(block_name, block_config) + else: + raise ValueError(f"Unknown decoder type: {decoder_type}") + + def _build_rotary_embs_for_block(self, block_name: str, block_config: dict): + """Build rotary embeddings for a single block and its mixers.""" + mixer_config = block_config.get("mixer", {}) + mixer_type = mixer_config.get("type") + + if mixer_type == "attention": + self.rotary_embs[block_name] = self._create_rotary_emb_for_attention(mixer_config) + elif mixer_type == "stochastic": + mixers = mixer_config.get("mixers", {}) + nested_dict = nn.ModuleDict() + for mixer_name, sub_mixer_config in mixers.items(): + if sub_mixer_config.get("type") == "attention": + nested_dict[mixer_name] = self._create_rotary_emb_for_attention(sub_mixer_config) + if len(nested_dict) > 0: + self.rotary_embs[block_name] = nested_dict + + def _create_causal_mask( + self, + attn_config, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.LongTensor, + past_key_values: Optional[Apriel2Cache], + cache_position: torch.Tensor, + ) -> Optional[Union[torch.Tensor, BlockMask]]: + """Create causal mask for an attention config.""" + + mask_function = create_causal_mask if attn_config.sliding_window is None else create_sliding_window_causal_mask + return mask_function( + config=attn_config, + input_embeds=input_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + def _compute_position_embeddings_and_masks( + self, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.LongTensor, + past_key_values: Optional[Apriel2Cache], + cache_position: torch.Tensor, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Compute position embeddings and attention masks for all unique attention blocks.""" + position_embeddings = {} + attention_masks = {} + decoder_type = self.config.decoder.get("type", "fixed") + + if decoder_type == "fixed": + block_config = self.config.decoder.get("block", {}) + self._compute_for_block( + "block", + block_config, + input_embeds, + attention_mask, + position_ids, + past_key_values, + cache_position, + position_embeddings, + attention_masks, + ) + elif decoder_type == "pattern": + blocks = self.config.decoder.get("blocks", {}) + for block_name, block_config in blocks.items(): + self._compute_for_block( + block_name, + block_config, + input_embeds, + attention_mask, + position_ids, + past_key_values, + cache_position, + position_embeddings, + attention_masks, + ) + else: + raise ValueError(f"Unknown decoder type: {decoder_type}") + + return position_embeddings, attention_masks + + def _compute_for_block( + self, + block_name: str, + block_config: dict, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: torch.LongTensor, + past_key_values: Optional[Apriel2Cache], + cache_position: torch.Tensor, + position_embeddings: dict[str, Any], + attention_masks: dict[str, Any], + ) -> None: + """Compute position embeddings and attention masks for a block.""" + mixer_config = block_config.get("mixer", {}) + mixer_type = mixer_config.get("type") + + if mixer_type == "attention": + rotary_emb = self.rotary_embs[block_name] + cos, sin = rotary_emb(input_embeds, 
position_ids) + attn_config = self._build_attn_config_for_mask(mixer_config) + causal_mask = self._create_causal_mask( + attn_config, input_embeds, attention_mask, position_ids, past_key_values, cache_position + ) + + position_embeddings[block_name] = (cos, sin) + attention_masks[block_name] = causal_mask + + elif mixer_type == "stochastic": + mixers = mixer_config.get("mixers", {}) + nested_pos_embs = {} + nested_masks = {} + + for mixer_name, sub_mixer_config in mixers.items(): + if sub_mixer_config.get("type") == "attention": + rotary_emb = self.rotary_embs[block_name][mixer_name] + cos, sin = rotary_emb(input_embeds, position_ids) + attn_config = self._build_attn_config_for_mask(sub_mixer_config) + causal_mask = self._create_causal_mask( + attn_config, input_embeds, attention_mask, position_ids, past_key_values, cache_position + ) + + nested_pos_embs[mixer_name] = (cos, sin) + nested_masks[mixer_name] = causal_mask + + if nested_pos_embs: + position_embeddings[block_name] = nested_pos_embs + attention_masks[block_name] = nested_masks + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Apriel2Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = Apriel2Cache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + position_embeddings, causal_masks = self._compute_position_embeddings_and_masks( + inputs_embeds, attention_mask, position_ids, past_key_values, cache_position + ) + + hidden_states = inputs_embeds + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + block_name = self.config.get_block_name(layer_idx) + layer_position_embeddings = position_embeddings.get(block_name) + 
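+ # The position_embeddings and causal_masks lookups are keyed by block name. For plain
+ # attention blocks the entry is a (cos, sin) tuple / a causal mask; for stochastic blocks it
+ # is a nested dict keyed by sub-mixer name (only attention-type sub-mixers get entries),
+ # e.g. roughly {"attention": (cos, sin), "swa": (cos, sin)} under illustrative mixer names.
+ # Blocks whose mixer needs neither (pure SSM blocks) simply get None here.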
layer_attention_mask = causal_masks.get(block_name) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=layer_attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=layer_position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache = past_key_values + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Apriel2ForCausalLM(Apriel2PreTrainedModel, GenerationMixin): + """Apriel2 model with a language modeling head.""" + + def __init__(self, config: Apriel2Config): + super().__init__(config) + self.model = Apriel2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Apriel2Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[tuple, CausalLMOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Forward through model + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift for next-token prediction + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + shift_labels = 
shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py index 40c4cfa8..e6358443 100644 --- a/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py +++ b/fast_llm_external_models/apriel_hybrid_ssm/modeling_apriel_hybrid_ssm.py @@ -18,7 +18,7 @@ from transformers.modeling_utils import PreTrainedModel from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm from transformers.processing_utils import Unpack -from transformers.utils import LossKwargs, logging +from transformers.utils import TransformersKwargs, logging from transformers.utils.generic import ModelOutput from fast_llm_external_models.apriel_hybrid_ssm.configuration_apriel_hybrid_ssm import AprielHybridSSMConfig @@ -1199,18 +1199,18 @@ def __init__(self, config: AprielHybridSSMConfig, **kwargs): super().__init__(config_copy, **kwargs) self.config = config blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") - for layer_idx, type in enumerate(config.hybrid_block_layout): - if type == "m2d": + logger.info(f"Loading hybrid model with the following layout: {config.hybrid_block_layout}") + for layer_idx, block_type in enumerate(config.hybrid_block_layout): + if block_type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx)) - elif type == "m2": + elif block_type == "m2": blocks.append(AprielSSMM2DecoderLayer(config, layer_idx)) - elif type == "t": + elif block_type == "t": blocks.append(MistralDecoderLayer(config, layer_idx)) - elif type == "i": + elif block_type == "i": blocks.append(AprielHybridIdentity(config)) else: - raise ValueError(f"Invalid block type: {type}") + raise ValueError(f"Invalid block type: {block_type}") self.layers = nn.ModuleList(blocks) # Initialize weights and apply final processing @@ -1252,9 +1252,6 @@ def forward( return output -class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... - - class AprielHybridSSMPreTrainedModel(PreTrainedModel): config_class = AprielHybridSSMConfig base_model_prefix = "model" @@ -1383,7 +1380,7 @@ def forward( output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, CausalLMOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/fast_llm_external_models/tests/__init__.py b/fast_llm_external_models/tests/__init__.py new file mode 100644 index 00000000..260aa65f --- /dev/null +++ b/fast_llm_external_models/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for fast_llm_external_models package.""" diff --git a/fast_llm_external_models/tests/conftest.py b/fast_llm_external_models/tests/conftest.py new file mode 100644 index 00000000..50645991 --- /dev/null +++ b/fast_llm_external_models/tests/conftest.py @@ -0,0 +1,15 @@ +"""Shared test fixtures for fast_llm_external_models. 
+ +This conftest.py contains only fixtures that are shared across multiple model test suites. +Model-specific fixtures should be in the respective model's test directory +(e.g., test_apriel2/conftest.py, test_apriel_hybrid_ssm/conftest.py). +""" + +import pytest +import torch + + +@pytest.fixture +def device(): + """Get available device (CPU or CUDA).""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/fast_llm_external_models/tests/test_apriel2/__init__.py b/fast_llm_external_models/tests/test_apriel2/__init__.py new file mode 100644 index 00000000..b821099f --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/__init__.py @@ -0,0 +1 @@ +"""Tests for Apriel2 model implementation.""" diff --git a/fast_llm_external_models/tests/test_apriel2/conftest.py b/fast_llm_external_models/tests/test_apriel2/conftest.py new file mode 100644 index 00000000..4cadc988 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/conftest.py @@ -0,0 +1,175 @@ +"""Test fixtures for Apriel2 model tests.""" + +import pytest +import torch + + +@pytest.fixture +def apriel2_config_tiny(): + """Tiny Apriel2 config for fast testing.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + ) + + +@pytest.fixture +def apriel2_config_stochastic(): + """Apriel2 config with stochastic mixer for testing routing.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + decoder={ + "type": "pattern", + "pattern": ["attn", "stoch"], + "blocks": { + "attn": {"mixer": {"type": "attention"}}, + "stoch": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention", "sliding_window": 4096}, + "mamba": { + "type": "mamba", + "conv_bias": True, + "dt_proj_bias": True + } + } + } + } + } + } + ) + + +@pytest.fixture +def apriel2_config_multi_mixer(): + """Apriel2 config with multiple mixers of same type.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + decoder={ + "type": "pattern", + "pattern": ["multi"], + "blocks": { + "multi": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attn_small", + "mixers": { + "attn_small": {"type": "attention", "sliding_window": 2048}, + "attn_large": {"type": "attention", "sliding_window": 8192}, + "mamba_v1": { + "type": "mamba", + "conv_bias": True, + "dt_proj_bias": True + }, + "mamba_v2": { + "type": "mamba", + "conv_bias": True, + "dt_proj_bias": True + } + } + } + } + } + } + ) + + +@pytest.fixture +def apriel2_config_all_mixers(): + """Apriel2 config with all 4 mixer types in one stochastic mixer. 
+ + This config is critical for testing: + - All mixer types work (attention, swa, mamba, gated_delta_net) + - Cache correctly isolates different mixer types + - Switching between mixers preserves independent state + """ + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + return Apriel2Config( + vocab_size=100, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + decoder={ + "type": "pattern", + "pattern": ["attn", "all_mixers"], + "blocks": { + "attn": {"mixer": {"type": "attention"}}, + "all_mixers": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": { + "type": "attention" + }, + "swa": { + "type": "attention", + "sliding_window": 2048 + }, + "mamba": { + "type": "mamba", + "conv_bias": True, + "dt_proj_bias": True + }, + "gated_delta_net": { + "type": "gated_delta_net" + } + } + } + } + } + } + ) + + +@pytest.fixture +def apriel2_cache(apriel2_config_tiny): + """Create empty Apriel2Cache from tiny config.""" + from fast_llm_external_models.apriel2.cache import Apriel2Cache + + return Apriel2Cache(apriel2_config_tiny) + + +@pytest.fixture +def sample_input_ids(): + """Sample input token IDs for testing.""" + return torch.randint(0, 100, (2, 10)) # batch_size=2, seq_len=10 + + +@pytest.fixture +def sample_attention_states(): + """Sample attention key/value states for cache testing.""" + batch_size, num_heads, seq_len, head_dim = 2, 8, 10, 64 + key = torch.randn(batch_size, num_heads, seq_len, head_dim) + value = torch.randn(batch_size, num_heads, seq_len, head_dim) + return key, value + + +@pytest.fixture +def sample_ssm_states(): + """Sample SSM conv/recurrent states for cache testing.""" + batch_size, d_inner, d_conv = 2, 128, 4 + conv = torch.randn(batch_size, d_inner, d_conv) + recurrent = torch.randn(batch_size, d_inner, 16) # d_state=16 + return conv, recurrent diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache.py b/fast_llm_external_models/tests/test_apriel2/test_cache.py new file mode 100644 index 00000000..d10a935a --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_cache.py @@ -0,0 +1,146 @@ +"""Unit tests for Apriel2Cache.""" + +import pytest +import torch +from fast_llm_external_models.apriel2.cache import Apriel2Cache + + +class TestCacheBasics: + """Test basic cache creation and properties.""" + + def test_cache_creation(self, apriel2_config_tiny): + """Test cache creation from config.""" + cache = Apriel2Cache(apriel2_config_tiny) + assert len(cache) == apriel2_config_tiny.num_hidden_layers + assert cache.is_compileable == False + assert cache.is_initialized == False + assert isinstance(cache.is_sliding, list) + assert len(cache.is_sliding) == apriel2_config_tiny.num_hidden_layers + + def test_cache_properties_empty(self, apriel2_cache): + """Test cache properties when empty.""" + assert apriel2_cache.is_initialized == False + assert apriel2_cache.has_previous_state == False + assert apriel2_cache.max_batch_size is None + assert apriel2_cache.max_cache_len is None + + +class TestAttentionCache: + """Test attention cache operations.""" + + def test_attention_update(self, apriel2_cache, sample_attention_states): + """Test updating attention cache.""" + key, value = sample_attention_states + k_out, v_out = apriel2_cache.update(key, value, layer_idx=0) + + assert k_out.shape == key.shape + assert v_out.shape == value.shape + assert apriel2_cache.is_initialized == True + assert apriel2_cache.get_seq_length(0) == 
key.shape[2] + + def test_attention_concatenation(self, apriel2_cache, sample_attention_states): + """Test that cache concatenates new states.""" + key1, value1 = sample_attention_states + apriel2_cache.update(key1, value1, layer_idx=0) + + # Add more tokens + key2 = torch.randn(2, 8, 5, 64) + value2 = torch.randn(2, 8, 5, 64) + k_out, v_out = apriel2_cache.update(key2, value2, layer_idx=0) + + assert k_out.shape[2] == 15 # 10 + 5 + assert apriel2_cache.get_seq_length(0) == 15 + + +class TestSSMCache: + """Test SSM cache operations.""" + + def test_ssm_direct_access(self, apriel2_config_stochastic): + """Test direct SSM state access.""" + cache = Apriel2Cache(apriel2_config_stochastic) + + # Set active mixer to mamba + cache.set_active_mixer(1, "mamba") + + # Set conv states + conv = torch.randn(2, 128, 4) + cache.conv_states[1] = conv + + # Retrieve and verify + retrieved = cache.conv_states[1] + assert retrieved is not None + assert torch.allclose(retrieved, conv) + + +class TestStochasticMixer: + """Test stochastic mixer cache routing.""" + + def test_set_active_mixer(self, apriel2_config_stochastic): + """Test setting active mixer.""" + cache = Apriel2Cache(apriel2_config_stochastic) + cache.set_active_mixer(1, "attention") + assert cache.active_mixers[1] == "attention" + + def test_routing_to_different_mixers(self, apriel2_config_stochastic, sample_attention_states): + """Test that different mixers use separate caches.""" + cache = Apriel2Cache(apriel2_config_stochastic) + key, value = sample_attention_states + + # Use attention mixer + cache.set_active_mixer(1, "attention") + cache.update(key, value, layer_idx=1) + attn_len = cache.get_seq_length(1) + + # Switch to mamba mixer - should have empty cache + cache.set_active_mixer(1, "mamba") + mamba_len = cache.get_seq_length(1) + + assert attn_len == 10 + assert mamba_len == 0 # Different cache + + +class TestBeamSearch: + """Test beam search operations.""" + + def test_batch_repeat_interleave(self, apriel2_cache, sample_attention_states): + """Test repeating cache for beam search.""" + key, value = sample_attention_states + apriel2_cache.update(key, value, layer_idx=0) + + apriel2_cache.batch_repeat_interleave(2) + assert apriel2_cache.max_batch_size == 4 # 2 * 2 + + def test_reorder_cache(self, apriel2_cache, sample_attention_states): + """Test reordering cache for beam search.""" + key, value = sample_attention_states + apriel2_cache.update(key, value, layer_idx=0) + + beam_idx = torch.tensor([1, 0]) + apriel2_cache.reorder_cache(beam_idx) + + # Cache should still be valid + assert apriel2_cache.is_initialized == True + + +class TestCacheReset: + """Test cache reset operations.""" + + def test_reset(self, apriel2_cache, sample_attention_states): + """Test resetting cache.""" + key, value = sample_attention_states + apriel2_cache.update(key, value, layer_idx=0) + + assert apriel2_cache.is_initialized == True + + apriel2_cache.reset() + + assert apriel2_cache.is_initialized == False + assert apriel2_cache.get_seq_length(0) == 0 + + def test_crop(self, apriel2_cache, sample_attention_states): + """Test cropping cache to max length.""" + key, value = sample_attention_states + apriel2_cache.update(key, value, layer_idx=0) + + apriel2_cache.crop(5) + assert apriel2_cache.get_seq_length(0) == 5 diff --git a/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py b/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py new file mode 100644 index 00000000..220bc2cf --- /dev/null +++ 
b/fast_llm_external_models/tests/test_apriel2/test_cache_routing.py @@ -0,0 +1,291 @@ +"""Tests for stochastic mixer cache routing and bug fixes.""" + +import pytest +import torch +from fast_llm_external_models.apriel2.cache import Apriel2Cache + + +class TestHasPreviousState: + """Test has_previous_state property with stochastic mixers.""" + + def test_checks_all_sub_caches(self, apriel2_config_stochastic): + """Test that has_previous_state checks ALL sub-caches, not just main mixer.""" + cache = Apriel2Cache(apriel2_config_stochastic) + + # Initially no SSM state + assert cache.has_previous_state == False + + # Set active mixer to mamba (NOT the main mixer which is attention) + cache.set_active_mixer(1, "mamba") + cache.conv_states[1] = torch.randn(2, 128, 4) + + # Should detect SSM state even though main mixer is "attention" + assert cache.has_previous_state == True + + def test_detects_any_ssm_cache(self, apriel2_config_multi_mixer): + """Test that has_previous_state detects SSM state in any sub-cache.""" + cache = Apriel2Cache(apriel2_config_multi_mixer) + + # Fill mamba_v1 + cache.set_active_mixer(0, "mamba_v1") + cache.conv_states[0] = torch.randn(2, 128, 4) + + # Fill mamba_v2 + cache.set_active_mixer(0, "mamba_v2") + cache.conv_states[0] = torch.randn(2, 128, 4) + + # Should detect SSM state from either variant + assert cache.has_previous_state == True + + +class TestPropertyAccessorGuards: + """Test that property accessors guard against None active_mixer.""" + + def test_get_raises_error_without_active_mixer(self, apriel2_config_stochastic): + """Test that accessing cache without set_active_mixer raises clear error.""" + cache = Apriel2Cache(apriel2_config_stochastic) + + with pytest.raises(RuntimeError) as exc_info: + _ = cache.conv_states[1] + + assert "requires set_active_mixer()" in str(exc_info.value) + assert "Available mixers:" in str(exc_info.value) + + def test_set_raises_error_without_active_mixer(self, apriel2_config_stochastic): + """Test that setting cache without set_active_mixer raises clear error.""" + cache = Apriel2Cache(apriel2_config_stochastic) + + with pytest.raises(RuntimeError) as exc_info: + cache.conv_states[1] = torch.randn(2, 128, 4) + + assert "requires set_active_mixer()" in str(exc_info.value) + + def test_access_works_after_set_active_mixer(self, apriel2_config_stochastic): + """Test that access works correctly after set_active_mixer.""" + cache = Apriel2Cache(apriel2_config_stochastic) + + # Set active mixer + cache.set_active_mixer(1, "mamba") + + # Now access should work + cache.conv_states[1] = torch.randn(2, 128, 4) + retrieved = cache.conv_states[1] + + assert retrieved is not None + + +class TestMixerSwitching: + """Test cache behavior when switching between different mixers.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="SSM mixers require CUDA") + def test_cache_preserves_state_across_mixer_switches(self, apriel2_config_all_mixers, device): + """Verify cache maintains independent state for each mixer when switching. + + This is the critical test for stochastic mixers: when we switch which mixer + is active, the cache must preserve previous mixer states while updating the + current mixer's state. 
+ """ + if device.type != "cuda": + pytest.skip("SSM mixers require CUDA device") + + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + model = Apriel2ForCausalLM(apriel2_config_all_mixers).to(device) + model.eval() + + stochastic_layer_idx = 1 # Layer 1 is the stochastic layer + stochastic_layer = model.model.layers[stochastic_layer_idx] + input_ids = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10), device=device) + + # Forward 1: Use attention (default main mixer) + stochastic_layer.mixer.main_mixer_name = "attention" + outputs1 = model(input_ids, use_cache=True) + cache = outputs1.past_key_values + + # Verify: only attention has data + layer_cache = cache.layers[stochastic_layer_idx] + assert layer_cache['attention'].key is not None, "Attention cache should have KV states" + assert layer_cache['swa'].key is None, "SWA cache should be empty" + assert layer_cache['mamba'].conv is None, "Mamba cache should be empty" + assert layer_cache['gated_delta_net'].conv is None, "GatedDeltaNet cache should be empty" + attn_seq_len_1 = layer_cache['attention'].key.shape[-2] + + # Forward 2: Switch to mamba (new token) + stochastic_layer.mixer.main_mixer_name = "mamba" + new_token = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 1), device=device) + outputs2 = model(new_token, past_key_values=cache, use_cache=True) + cache = outputs2.past_key_values + + # Verify: attention preserved, mamba added + assert layer_cache['attention'].key is not None, "Attention cache should be preserved" + assert layer_cache['attention'].key.shape[-2] == attn_seq_len_1, "Attention seq_len should not change" + assert layer_cache['mamba'].conv is not None, "Mamba cache should now have SSM states" + assert layer_cache['swa'].key is None, "SWA cache should still be empty" + assert layer_cache['gated_delta_net'].conv is None, "GatedDeltaNet cache should still be empty" + + # Forward 3: Switch to swa + stochastic_layer.mixer.main_mixer_name = "swa" + outputs3 = model(new_token, past_key_values=cache, use_cache=True) + cache = outputs3.past_key_values + + # Verify: attention + mamba preserved, swa added + assert layer_cache['attention'].key is not None, "Attention cache should be preserved" + assert layer_cache['mamba'].conv is not None, "Mamba cache should be preserved" + assert layer_cache['swa'].key is not None, "SWA cache should now have KV states" + assert layer_cache['gated_delta_net'].conv is None, "GatedDeltaNet cache should still be empty" + + # Forward 4: Switch to gated_delta_net + stochastic_layer.mixer.main_mixer_name = "gated_delta_net" + outputs4 = model(new_token, past_key_values=cache, use_cache=True) + cache = outputs4.past_key_values + + # Verify: ALL mixers now have independent state + assert layer_cache['attention'].key is not None, "Attention cache should be preserved" + assert layer_cache['mamba'].conv is not None, "Mamba cache should be preserved" + assert layer_cache['swa'].key is not None, "SWA cache should be preserved" + assert layer_cache['gated_delta_net'].conv is not None, "GatedDeltaNet cache should now have SSM states" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="SSM mixers require CUDA") + def test_cache_isolation_between_attention_and_ssm(self, apriel2_config_all_mixers, device): + """Verify attention caches (KV) and SSM caches (conv/recurrent) don't interfere.""" + if device.type != "cuda": + pytest.skip("SSM mixers require CUDA device") + + from fast_llm_external_models.apriel2.modeling_apriel2 import 
Apriel2ForCausalLM + + model = Apriel2ForCausalLM(apriel2_config_all_mixers).to(device) + model.eval() + + stochastic_layer_idx = 1 + stochastic_layer = model.model.layers[stochastic_layer_idx] + input_ids = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10), device=device) + + # Forward with attention + stochastic_layer.mixer.main_mixer_name = "attention" + outputs1 = model(input_ids, use_cache=True) + cache = outputs1.past_key_values + + # Get attention cache state + attn_cache = cache.layers[stochastic_layer_idx]['attention'] + attn_key = attn_cache.key.clone() + attn_value = attn_cache.value.clone() + + # Forward with mamba (using same cache) + stochastic_layer.mixer.main_mixer_name = "mamba" + new_token = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 1), device=device) + outputs2 = model(new_token, past_key_values=cache, use_cache=True) + cache = outputs2.past_key_values + + # Verify attention cache unchanged + assert torch.allclose(cache.layers[stochastic_layer_idx]['attention'].key, attn_key), \ + "Attention KV cache should not be modified when mamba is active" + assert torch.allclose(cache.layers[stochastic_layer_idx]['attention'].value, attn_value), \ + "Attention KV cache should not be modified when mamba is active" + + # Verify mamba cache is populated + assert cache.layers[stochastic_layer_idx]['mamba'].conv is not None, \ + "Mamba SSM cache should be populated" + + def test_seq_len_tracking_per_mixer(self, apriel2_config_all_mixers): + """Verify seq_len is tracked independently for each mixer.""" + from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + model = Apriel2ForCausalLM(apriel2_config_all_mixers) + model.eval() + + stochastic_layer_idx = 1 + stochastic_layer = model.model.layers[stochastic_layer_idx] + + # Forward with attention (10 tokens) + input_ids1 = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 10)) + stochastic_layer.mixer.main_mixer_name = "attention" + outputs1 = model(input_ids1, use_cache=True) + cache = outputs1.past_key_values + + cache.set_active_mixer(stochastic_layer_idx, "attention") + assert cache.get_seq_length(stochastic_layer_idx) == 10 + + # Forward with swa (5 tokens) - independent from attention + input_ids2 = torch.randint(0, apriel2_config_all_mixers.vocab_size, (2, 5)) + stochastic_layer.mixer.main_mixer_name = "swa" + outputs2 = model(input_ids2, use_cache=True) + cache2 = Apriel2Cache(apriel2_config_all_mixers) # Fresh cache for swa + outputs2 = model(input_ids2, past_key_values=cache2, use_cache=True) + cache2 = outputs2.past_key_values + + cache2.set_active_mixer(stochastic_layer_idx, "swa") + assert cache2.get_seq_length(stochastic_layer_idx) == 5 + + # Original cache should still have attention with seq_len=10 + cache.set_active_mixer(stochastic_layer_idx, "attention") + assert cache.get_seq_length(stochastic_layer_idx) == 10 + + +class TestMultipleMixersSameType: + """Test multiple mixers of the same type with independent caches.""" + + def test_attention_variants_independent(self, apriel2_config_multi_mixer): + """Test that different attention mixers have independent caches.""" + cache = Apriel2Cache(apriel2_config_multi_mixer) + + # Fill attn_small cache + cache.set_active_mixer(0, "attn_small") + key_small = torch.randn(2, 8, 10, 64) + value_small = torch.randn(2, 8, 10, 64) + cache.update(key_small, value_small, 0) + + assert cache.get_seq_length(0) == 10 + + # Switch to attn_large - should have empty cache + cache.set_active_mixer(0, "attn_large") + assert 
cache.get_seq_length(0) == 0 + + # Fill attn_large + key_large = torch.randn(2, 8, 5, 64) + value_large = torch.randn(2, 8, 5, 64) + cache.update(key_large, value_large, 0) + + assert cache.get_seq_length(0) == 5 + + # Switch back to attn_small - should still have original data + cache.set_active_mixer(0, "attn_small") + assert cache.get_seq_length(0) == 10 + + def test_ssm_variants_independent(self, apriel2_config_multi_mixer): + """Test that different SSM mixers have independent caches.""" + cache = Apriel2Cache(apriel2_config_multi_mixer) + + # Fill mamba_v1 + cache.set_active_mixer(0, "mamba_v1") + conv1 = torch.randn(2, 128, 4) + cache.conv_states[0] = conv1 + + # Fill mamba_v2 + cache.set_active_mixer(0, "mamba_v2") + conv2 = torch.randn(2, 128, 4) + cache.conv_states[0] = conv2 + + # Verify they're different + cache.set_active_mixer(0, "mamba_v1") + retrieved1 = cache.conv_states[0] + + cache.set_active_mixer(0, "mamba_v2") + retrieved2 = cache.conv_states[0] + + assert not torch.allclose(retrieved1, retrieved2) + assert torch.allclose(retrieved1, conv1) + assert torch.allclose(retrieved2, conv2) + + def test_different_window_sizes(self, apriel2_config_multi_mixer): + """Test that attention mixers with different window sizes are independent.""" + cache = Apriel2Cache(apriel2_config_multi_mixer) + + # Check that attn_small and attn_large have different window sizes + cache.set_active_mixer(0, "attn_small") + window_small = cache.get_max_cache_shape(0) + + cache.set_active_mixer(0, "attn_large") + window_large = cache.get_max_cache_shape(0) + + assert window_small == 2048 + assert window_large == 8192 diff --git a/fast_llm_external_models/tests/test_apriel2/test_model_structure.py b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py new file mode 100644 index 00000000..86bcc661 --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_model_structure.py @@ -0,0 +1,122 @@ +"""Tests for Apriel2 model structure and architecture validation.""" + +import pytest +import torch +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM +from fast_llm_external_models.apriel2.cache import Apriel2Cache, _AttentionCache, _SSMCache + + +class TestStochasticMixerStructure: + """Validate stochastic mixer architecture matches configuration.""" + + def test_all_submixers_present(self, apriel2_config_all_mixers): + """Stochastic layer contains all 4 configured sub-mixers.""" + model = Apriel2ForCausalLM(apriel2_config_all_mixers) + stochastic_layer = model.model.layers[1] # Layer 1 is the "all_mixers" layer + + assert hasattr(stochastic_layer.mixer, 'mixers'), "Stochastic mixer should have 'mixers' attribute" + assert set(stochastic_layer.mixer.mixers.keys()) == { + 'attention', 'swa', 'mamba', 'gated_delta_net' + }, "Stochastic mixer should contain all 4 configured mixer types" + + # Verify each mixer is the correct type + from fast_llm_external_models.apriel2.modeling_apriel2 import ( + Apriel2Attention, Apriel2Mamba, Apriel2GatedDeltaNet + ) + + assert isinstance(stochastic_layer.mixer.mixers['attention'], Apriel2Attention) + assert isinstance(stochastic_layer.mixer.mixers['swa'], Apriel2Attention) # SWA is Apriel2Attention with sliding_window + assert isinstance(stochastic_layer.mixer.mixers['mamba'], Apriel2Mamba) + assert isinstance(stochastic_layer.mixer.mixers['gated_delta_net'], Apriel2GatedDeltaNet) + + def test_main_mixer_is_configured(self, apriel2_config_all_mixers): + """Verify main_mixer_name is set correctly.""" + model = 
Apriel2ForCausalLM(apriel2_config_all_mixers) + stochastic_layer = model.model.layers[1] + + assert stochastic_layer.mixer.main_mixer_name == "attention" + assert stochastic_layer.mixer.main_mixer_name in stochastic_layer.mixer.mixers + + def test_cache_has_all_submixer_slots(self, apriel2_config_all_mixers): + """Cache for stochastic layer has dict with all mixer slots.""" + cache = Apriel2Cache(apriel2_config_all_mixers) + layer_cache = cache.layers[1] # stochastic layer + + assert isinstance(layer_cache, dict), "Stochastic layer cache should be a dict" + assert set(layer_cache.keys()) == { + 'attention', 'swa', 'mamba', 'gated_delta_net' + }, "Cache should have slots for all 4 mixers" + + def test_attention_mixers_use_attention_cache(self, apriel2_config_all_mixers): + """Attention and SWA use _AttentionCache, SSMs use _SSMCache.""" + cache = Apriel2Cache(apriel2_config_all_mixers) + layer_cache = cache.layers[1] + + # Attention-based mixers use AttentionCache + assert isinstance(layer_cache['attention'], _AttentionCache) + assert isinstance(layer_cache['swa'], _AttentionCache) + + # SSM-based mixers use SSMCache + assert isinstance(layer_cache['mamba'], _SSMCache) + assert isinstance(layer_cache['gated_delta_net'], _SSMCache) + + def test_parameter_counts_differ_by_config(self): + """Different configs create models with different parameter counts.""" + from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config + + config_tiny = Apriel2Config( + vocab_size=100, hidden_size=64, num_hidden_layers=2, + num_attention_heads=4, num_key_value_heads=2 + ) + + config_stochastic = Apriel2Config( + vocab_size=100, hidden_size=64, num_hidden_layers=2, + num_attention_heads=4, num_key_value_heads=2, + decoder={ + "type": "pattern", + "pattern": ["attn", "stoch"], + "blocks": { + "attn": {"mixer": {"type": "attention"}}, + "stoch": { + "mixer": { + "type": "stochastic", + "main_mixer_name": "attention", + "mixers": { + "attention": {"type": "attention"}, + "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True} + } + } + } + } + } + ) + + model_tiny = Apriel2ForCausalLM(config_tiny) + model_stochastic = Apriel2ForCausalLM(config_stochastic) + + params_tiny = sum(p.numel() for p in model_tiny.parameters()) + params_stochastic = sum(p.numel() for p in model_stochastic.parameters()) + + assert params_stochastic > params_tiny, \ + "Stochastic mixer should have more parameters (has both attention and mamba)" + + def test_weights_are_initialized(self, apriel2_config_all_mixers): + """Verify model weights are initialized (not all zeros/constant).""" + model = Apriel2ForCausalLM(apriel2_config_all_mixers) + + # Check that model has parameters + stochastic_layer = model.model.layers[1] + total_params = sum(p.numel() for p in stochastic_layer.mixer.parameters()) + assert total_params > 0, "Stochastic mixer should have parameters" + + # Basic sanity: at least some parameters should be non-zero + non_zero_params = sum( + not torch.all(p == 0) + for mixer in stochastic_layer.mixer.mixers.values() + for p in mixer.parameters() + ) + assert non_zero_params > 0, "At least some mixer parameters should be non-zero" + + # Note: We don't check detailed statistics because: + # - SSMs use special initialization (dt_proj uses log-spaced values, high mean) + # - Some parameters may be intentionally constant (e.g., bias terms) diff --git a/fast_llm_external_models/tests/test_apriel2/test_modeling.py b/fast_llm_external_models/tests/test_apriel2/test_modeling.py new file mode 100644 index 
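As a companion to the structure assertions above, here is a minimal sketch (not part of the test suite; it reuses the exact decoder dict from test_parameter_counts_differ_by_config and assumes layer 1 maps to the "stoch" block of the pattern) showing what the stochastic layer and its cache look like:

from fast_llm_external_models.apriel2.cache import Apriel2Cache
from fast_llm_external_models.apriel2.configuration_apriel2 import Apriel2Config
from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM

# Same decoder layout as config_stochastic above: layer 0 = "attn", layer 1 = "stoch".
config = Apriel2Config(
    vocab_size=100, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2,
    decoder={
        "type": "pattern",
        "pattern": ["attn", "stoch"],
        "blocks": {
            "attn": {"mixer": {"type": "attention"}},
            "stoch": {
                "mixer": {
                    "type": "stochastic",
                    "main_mixer_name": "attention",
                    "mixers": {
                        "attention": {"type": "attention"},
                        "mamba": {"type": "mamba", "conv_bias": True, "dt_proj_bias": True},
                    },
                }
            },
        },
    },
)

model = Apriel2ForCausalLM(config)
print(sorted(model.model.layers[1].mixer.mixers.keys()))  # ['attention', 'mamba']
print(type(Apriel2Cache(config).layers[1]))               # dict: one cache slot per sub-mixer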
00000000..e9b6256c --- /dev/null +++ b/fast_llm_external_models/tests/test_apriel2/test_modeling.py @@ -0,0 +1,107 @@ +"""Tests for Apriel2 model instantiation, forward pass, and generation.""" + +import pytest +import torch +from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2ForCausalLM + + +class TestApriel2Modeling: + """End-to-end tests for Apriel2 model with different configurations.""" + + @pytest.mark.parametrize("config_name", [ + "apriel2_config_tiny", + "apriel2_config_stochastic", + "apriel2_config_multi_mixer", + "apriel2_config_all_mixers" # Tests all 4 mixer types + ]) + def test_model_end_to_end(self, config_name, request): + """Test instantiation, forward pass, cache correctness, and generation. + + This comprehensive test validates: + 1. Model can be instantiated from config + 2. Forward pass produces correct output shapes + 3. Cache is actually being used (not dormant) + 4. Cache produces numerically identical results to non-cached computation + 5. Generation works end-to-end + + The cache correctness check is critical for stochastic mixer configs, + as it validates that set_active_mixer() is called correctly and cache + routing works in the actual model (not just in isolation). + """ + config = request.getfixturevalue(config_name) + + # Use longer sequences for better cache validation + seq_len = 50 + input_ids = torch.randint(0, config.vocab_size, (2, seq_len)) + + # 1. Instantiation + model = Apriel2ForCausalLM(config) + model.eval() # Disable dropout for deterministic results + assert model is not None + + # 2. Forward pass - basic shape validation + outputs = model(input_ids, use_cache=False) + assert outputs.logits.shape == (2, seq_len, config.vocab_size) + assert hasattr(outputs, 'logits') + + # 3. Verify cache is actually being used (not dormant) + split_pos = 30 + + # Forward with correct cache + outputs_part1 = model(input_ids[:, :split_pos], use_cache=True) + assert outputs_part1.past_key_values is not None + + outputs_correct_cache = model( + input_ids[:, split_pos:split_pos+1], + past_key_values=outputs_part1.past_key_values, + use_cache=True + ) + + # Forward with WRONG cache (zeros) - should give different results if cache is used + from fast_llm_external_models.apriel2.cache import Apriel2Cache + wrong_cache = Apriel2Cache(config) + # Initialize with zeros (wrong state) + for layer_idx in range(config.num_hidden_layers): + # For attention layers, initialize empty cache + if hasattr(wrong_cache.layers[layer_idx], 'key_cache'): + wrong_cache.layers[layer_idx].key_cache = torch.zeros(2, 4, 1, 16) + wrong_cache.layers[layer_idx].value_cache = torch.zeros(2, 4, 1, 16) + + outputs_wrong_cache = model( + input_ids[:, split_pos:split_pos+1], + past_key_values=wrong_cache, + use_cache=True + ) + + # If cache is being used, wrong cache should give different results + cache_is_used = not torch.allclose( + outputs_correct_cache.logits, + outputs_wrong_cache.logits, + atol=1e-3 + ) + assert cache_is_used, f"Cache appears to be dormant for {config_name} - wrong cache gives same results as correct cache" + + # 4. 
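The cached-vs-uncached comparison carried out in steps 3-4 is the load-bearing check of this test; a compact helper sketch (hypothetical, shown only to make the pattern explicit, using the same model call signature as the test) would look like:

import torch

def assert_prefix_cache_consistent(model, input_ids, split=30, atol=1e-5):
    """Sketch: logits for the token at `split` must match whether or not the prefix comes from the cache."""
    full = model(input_ids, use_cache=False).logits
    prefix = model(input_ids[:, :split], use_cache=True)
    step = model(input_ids[:, split : split + 1], past_key_values=prefix.past_key_values, use_cache=True)
    assert torch.allclose(full[:, split, :], step.logits[:, 0, :], atol=atol)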
Cache correctness - validate cache produces same results as no-cache + # Compute full sequence without cache + outputs_full = model(input_ids, use_cache=False) + + # Compute in two steps with cache + outputs_part1 = model(input_ids[:, :split_pos], use_cache=True) + outputs_part2 = model( + input_ids[:, split_pos:split_pos+1], + past_key_values=outputs_part1.past_key_values, + use_cache=True + ) + + # Logits should match between cached and non-cached + assert torch.allclose( + outputs_full.logits[:, split_pos, :], + outputs_part2.logits[:, 0, :], + atol=1e-5 + ), f"Cache correctness failed for {config_name}: cached and non-cached logits differ" + + # 5. Generation - end-to-end validation + prompt = input_ids[:, :10] + generated = model.generate(prompt, max_new_tokens=10, use_cache=True) + assert generated.shape == (2, 20) # 10 prompt + 10 generated + assert torch.all(generated[:, :10] == prompt) # Prompt should be preserved diff --git a/pyproject.toml b/pyproject.toml index 70900b11..c7d3ffd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,10 @@ [tool.black] line-length = 119 target-version = ['py312'] + +[tool.pytest.ini_options] +testpaths = [ + "tests", # Fast-LLM core tests + "fast_llm_external_models/tests" # External models tests +] +norecursedirs = ["Megatron-LM"] diff --git a/setup.cfg b/setup.cfg index 77073ab5..8ec619a4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ HUGGINGFACE = # To install on cpu environment (ex. for IDE support): # MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = - mamba_ssm[causal-conv1d]==2.2.4 + mamba_ssm[causal-conv1d]>=2.2.4 cartesia_pytorch>=0.0.2 GENERATION = diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 3c3bfb83..f75ad5eb 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -325,7 +325,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): format=DistributedCheckpointFormat, load_config=ModelConfigType.model, ) - ) + ).eval() test_input = torch.randint( 0, model_ref.config.fast_llm_config.base_model.embeddings.vocab_size, @@ -334,21 +334,21 @@ def test_huggingface_model(model_testing_config, get_convert_path): device="cuda", ) output_ref = model_ref(test_input) - model_from_fast_llm = hf_class.from_pretrained(fast_llm_path) + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path).eval() model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, ) - ) + ).eval() errors = [] auto_model = ( transformers.AutoModel if model_testing_config.name in ("diffusion_llama", "dream") else transformers.AutoModelForCausalLM ) - model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() + model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda().eval() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7..c9c985c1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -12,6 +12,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.conversion.config import ( + Apriel2CheckpointFormat, AprielHybridSSMCheckpointFormat, 
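The .eval() calls added to tests/models/test_checkpoint.py matter because the converted models are compared output-for-output; a toy illustration of the failure mode train-mode dropout would introduce (plain torch behaviour, unrelated to Fast-LLM internals):

import torch

torch.manual_seed(0)
layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.5))
x = torch.randn(2, 8)

layer.train()
print(torch.allclose(layer(x), layer(x)))  # False in general: dropout masks differ per call

layer.eval()  # .eval() disables dropout, making repeated forwards deterministic
print(torch.allclose(layer(x), layer(x)))  # True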
DiffusionDreamCheckpointFormat, DiffusionLlamaCheckpointFormat, @@ -694,6 +695,97 @@ def _update_and_add_testing_config( ) +_update_and_add_testing_config( + # Tests apriel2 format with pattern decoder mixing all mixer types. + # This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention. + "llama", + "apriel2", + updates={ + ("model", "base_model", "tied_embedding_weight"): True, + ("model", "base_model", "decoder"): { + "type": "pattern", + "blocks": { + "attn_full": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "attention", + "rotary": {"type": "default", "theta": 10000}, + "heads": 8, + "head_groups": 4, + "head_size": 32, + "add_linear_biases": False, + }, + }, + "mamba": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "mamba_2", + "d_inner": 512, + "state_size": 16, + "dt_rank": 16, + "d_xb": 256, + "add_linear_biases": False, + }, + }, + "stochastic": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "stochastic", + "mixers": { + "attn": { + "type": "attention", + "rotary": {"type": "default", "theta": 10000}, + "heads": 8, + "head_groups": 4, + "head_size": 32, + "add_linear_biases": False, + }, + "mamba": { + "type": "mamba_2", + "d_inner": 512, + "state_size": 16, + "dt_rank": 16, + "d_xb": 256, + "add_linear_biases": False, + }, + }, + "sampling_strategy": "uniform", + "main_mixer_name": "attn", + }, + }, + "attn_swa": { + **copy.deepcopy(_llama_block), + "mixer": { + "type": "attention", + "rotary": {"type": "default", "theta": 10000}, + "heads": 8, + "head_groups": 4, + "head_size": 32, + "window_size": 128, + "add_linear_biases": False, + }, + }, + }, + "pattern": ["attn_full", "mamba", "stochastic", "attn_swa"], + "num_blocks": 4, + }, + }, + megatron_args=None, + checkpoint_format=Apriel2CheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=2.0, + # Micro-sequence split not supported for Mamba. + skip_tests=("sdp", "ms"), +) + + @pytest.fixture(scope="session", params=MODEL_CONFIGS.keys()) def model_testing_config(request) -> ModelTestingConfig: models = request.config.getoption("--models")
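For orientation, the "pattern" decoder above is read here as a cyclic schedule over num_blocks layers (an assumption based on how this config is laid out, not a statement about the decoder implementation); with num_blocks: 4 each block appears exactly once, and layer 2 is the stochastic block that samples uniformly between its "attn" and "mamba" sub-mixers:

pattern = ["attn_full", "mamba", "stochastic", "attn_swa"]

def expand(pattern: list[str], num_blocks: int) -> list[str]:
    # Hypothetical helper mirroring the assumed pattern-repetition semantics.
    return [pattern[i % len(pattern)] for i in range(num_blocks)]

print(expand(pattern, 4))  # ['attn_full', 'mamba', 'stochastic', 'attn_swa']
print(expand(pattern, 6))  # repeats: ['attn_full', 'mamba', 'stochastic', 'attn_swa', 'attn_full', 'mamba']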