|
| 1 | +import inspect |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Any, Dict, Optional |
| 4 | + |
| 5 | + |
@dataclass
class ModelConfig:
    """Configuration for a Moonshine-style encoder-decoder model.

    Any ``*_num_key_value_heads`` field left as ``None`` defaults to the
    corresponding ``*_num_attention_heads`` (i.e. plain multi-head
    attention rather than grouped-query attention).
    """

    model_type: str = "moonshine"
    vocab_size: int = 32768
    hidden_size: int = 288
    intermediate_size: int = 1152
    encoder_num_hidden_layers: int = 6
    decoder_num_hidden_layers: int = 6
    encoder_num_attention_heads: int = 8
    decoder_num_attention_heads: int = 8
    encoder_num_key_value_heads: Optional[int] = None
    decoder_num_key_value_heads: Optional[int] = None
    encoder_hidden_act: str = "gelu"
    decoder_hidden_act: str = "silu"
    max_position_embeddings: int = 512
    attention_bias: bool = False
    attention_dropout: float = 0.0
    partial_rotary_factor: float = 0.9
    rope_theta: float = 10000.0
    bos_token_id: int = 1
    eos_token_id: int = 2
    decoder_start_token_id: int = 1
    tie_word_embeddings: bool = True
    pad_head_dim_to_multiple_of: Optional[int] = None

    def __post_init__(self) -> None:
        # Default to one KV head per attention head (standard MHA) when the
        # KV-head counts are not given explicitly.
        if self.encoder_num_key_value_heads is None:
            self.encoder_num_key_value_heads = self.encoder_num_attention_heads
        if self.decoder_num_key_value_heads is None:
            self.decoder_num_key_value_heads = self.decoder_num_attention_heads

    @classmethod
    def from_dict(cls, params: Dict[str, Any]) -> "ModelConfig":
        """Build a config from a parameter dict, ignoring unknown keys.

        Extra keys in ``params`` (e.g. metadata carried alongside a
        checkpoint's config) are silently dropped so loading still works.
        """
        # Hoist the signature lookup out of the comprehension: the original
        # re-ran inspect.signature(cls) — an expensive introspection call —
        # once per key in `params`.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
0 commit comments