Skip to content

Commit f847627

Browse files
committed
add moonshine stt model
1 parent 5fac1de commit f847627

File tree

6 files changed

+751
-0
lines changed

6 files changed

+751
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .moonshine import Model, ModelConfig
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import inspect
2+
from dataclasses import dataclass
3+
from typing import Any, Dict, Optional
4+
5+
6+
@dataclass
7+
class ModelConfig:
8+
model_type: str = "moonshine"
9+
vocab_size: int = 32768
10+
hidden_size: int = 288
11+
intermediate_size: int = 1152
12+
encoder_num_hidden_layers: int = 6
13+
decoder_num_hidden_layers: int = 6
14+
encoder_num_attention_heads: int = 8
15+
decoder_num_attention_heads: int = 8
16+
encoder_num_key_value_heads: Optional[int] = None
17+
decoder_num_key_value_heads: Optional[int] = None
18+
encoder_hidden_act: str = "gelu"
19+
decoder_hidden_act: str = "silu"
20+
max_position_embeddings: int = 512
21+
attention_bias: bool = False
22+
attention_dropout: float = 0.0
23+
partial_rotary_factor: float = 0.9
24+
rope_theta: float = 10000.0
25+
bos_token_id: int = 1
26+
eos_token_id: int = 2
27+
decoder_start_token_id: int = 1
28+
tie_word_embeddings: bool = True
29+
pad_head_dim_to_multiple_of: Optional[int] = None
30+
31+
def __post_init__(self):
32+
if self.encoder_num_key_value_heads is None:
33+
self.encoder_num_key_value_heads = self.encoder_num_attention_heads
34+
if self.decoder_num_key_value_heads is None:
35+
self.decoder_num_key_value_heads = self.decoder_num_attention_heads
36+
37+
@classmethod
38+
def from_dict(cls, params: Dict[str, Any]) -> "ModelConfig":
39+
return cls(
40+
**{
41+
k: v
42+
for k, v in params.items()
43+
if k in inspect.signature(cls).parameters
44+
}
45+
)

0 commit comments

Comments
 (0)