|
11 | 11 | from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ |
12 | 12 |
|
13 | 13 | """ |
14 | | -Note: this is mostly addapted from https://github.com/Zyphra/Zamba2, similar code is aslo in https://github.com/state-spaces/mamba. |
| 14 | +Note: this is mostly adapted from https://github.com/Zyphra/Zamba2, similar code is also in https://github.com/state-spaces/mamba. |
15 | 15 | For now it only supports training and not inference. |
16 | 16 | This works with triton 3.1.0 |
17 | 17 | """ |
|
20 | 20 | def init_A(d_state, d_inner) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: |
21 | 21 | def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa |
22 | 22 | # S4D real initialization |
23 | | - # TODO: adopt this innitialization to work for tensor parallel setting! |
| 23 | + # TODO: adopt this initialization to work for tensor parallel setting! |
24 | 24 | A = einops.repeat(torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner).contiguous() |
25 | 25 | A_log = torch.log(A) # Keep A_log in fp32 |
26 | 26 | if tensor.shape != A_log.shape: |
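
For reference, a minimal standalone sketch of what the S4D real initialization above produces; the `d_state`/`d_inner` values are illustrative only, and the comment about recovering `A = -exp(A_log)` reflects how Mamba-style SSMs typically consume this parameter:

```python
# Minimal sketch of the S4D real initialization (illustrative sizes, not from the source).
import torch
import einops

d_state, d_inner = 4, 8

# Each row of A holds 1..d_state; the log is what gets stored so the model
# can later recover a strictly negative A = -exp(A_log) for stable SSM dynamics.
A = einops.repeat(
    torch.arange(1, d_state + 1, dtype=torch.float32), "n -> d n", d=d_inner
).contiguous()
A_log = torch.log(A)  # shape (d_inner, d_state), kept in fp32

print(A_log.shape)  # torch.Size([8, 4])
```
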
@@ -106,7 +106,7 @@ def __init__( |
106 | 106 | ) |
107 | 107 | self.x_proj.weight.auto_grad_accumulation = True |
108 | 108 |
|
109 | | - # TODO: the weights are innitialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 |
| 109 | + # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 |
110 | 110 | self.dt_proj_weight = ParameterMeta.from_dims( |
111 | 111 | (td_inner, tdt_rank), |
112 | 112 | init_method=kaiming_init_(tdt_rank.size), |
|
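
For context on the TODO above: the linked `mamba_simple.py` does not use a plain Kaiming init for `dt_proj`. It scales the weight by `dt_rank**-0.5` and sets the bias to the inverse softplus of a log-uniform sample in `[dt_min, dt_max]`. The sketch below is a hedged, standalone paraphrase of that reference scheme (names and default values mirror the reference, but this is not a drop-in for the `ParameterMeta` API used in this file):

```python
# Standalone paraphrase of the reference dt_proj initialization from mamba_simple.py.
# Sizes are illustrative; dt_min/dt_max/dt_scale/dt_init_floor are the reference defaults.
import math
import torch

d_inner, dt_rank = 256, 16
dt_min, dt_max, dt_scale, dt_init_floor = 1e-3, 1e-1, 1.0, 1e-4

# Weight: uniform in [-std, std] with std = dt_rank**-0.5 * dt_scale
dt_init_std = dt_rank**-0.5 * dt_scale
dt_proj_weight = torch.empty(d_inner, dt_rank).uniform_(-dt_init_std, dt_init_std)

# Bias: chosen so that softplus(bias) lands log-uniformly in [dt_min, dt_max]
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
).clamp(min=dt_init_floor)
dt_proj_bias = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
```
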