@@ -1,11 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import List, Optional, Tuple

 from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig,
                                                        TextEncoderConfig)


+def _is_transformer_layer(n: str, m) -> bool:
+    return "layers" in n and str.isdigit(n.split(".")[-1])
+
+
+def _is_embeddings(n: str, m) -> bool:
+    return n.endswith("embed_tokens")
+
+
+def _is_final_norm(n: str, m) -> bool:
+    return n.endswith("norm")
+
+
 @dataclass
 class LlamaArchConfig(TextEncoderArchConfig):
     vocab_size: int = 32000
@@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig):
     head_dim: Optional[int] = None
     hidden_state_skip_layer: int = 2
     text_len: int = 256
+    stacked_params_mapping: List[Tuple[str, str, str]] = field(
+        default_factory=lambda: [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),  # type: ignore
+            (".gate_up_proj", ".up_proj", 1),  # type: ignore
+        ])
+    _fsdp_shard_conditions: list = field(
+        default_factory=lambda:
+        [_is_transformer_layer, _is_embeddings, _is_final_norm])


 @dataclass
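For context on the new `stacked_params_mapping` field: it follows the vLLM convention for loading Hugging Face checkpoints into fused layers, routing separate `q_proj`/`k_proj`/`v_proj` (and `gate_proj`/`up_proj`) checkpoint tensors into a single fused `qkv_proj` (`gate_up_proj`) parameter. Below is a minimal sketch of the consuming loop; `load_weights` and the per-parameter `weight_loader` attribute are assumptions in the style of vLLM, not code from this commit:

```python
# Hypothetical loader loop, not from this commit.
import torch


def load_weights(model, weights, stacked_params_mapping):
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # e.g. "...self_attn.q_proj.weight" -> "...self_attn.qkv_proj.weight"
            param = params_dict[name.replace(shard_name, fused_name)]
            # The fused parameter's weight_loader copies the checkpoint tensor
            # into the q/k/v (or gate/up) slice selected by shard_id.
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No fused mapping matched: plain one-to-one copy.
            param = params_dict[name]
            with torch.no_grad():
                param.copy_(loaded_weight)
```

The mixed shard ids (strings for the attention shards, ints 0/1 for the MLP columns) are why the declared `List[Tuple[str, str, str]]` element type carries the inline `# type: ignore`s.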
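Similarly, each predicate in `_fsdp_shard_conditions` takes a dotted module name `n` and the module `m` (unused here) and answers whether that submodule should become its own FSDP shard unit: every numbered `layers.<i>` transformer block, the `embed_tokens` embedding, and the final `norm`. A minimal sketch of how such predicates can drive auto-wrapping, assuming `torch.distributed.fsdp`'s `lambda_auto_wrap_policy` (the actual fastvideo wiring may differ):

```python
# Hypothetical glue, not from this commit: turn (name, module) predicates
# like _is_transformer_layer / _is_embeddings / _is_final_norm into a policy.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def wrap_with_fsdp(model: nn.Module, shard_conditions) -> FSDP:
    # lambda_auto_wrap_policy only sees the module, so precompute module -> name.
    names = {m: n for n, m in model.named_modules()}
    lambda_fn = lambda m: any(cond(names[m], m) for cond in shard_conditions)
    policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn)
    return FSDP(model, auto_wrap_policy=policy)
```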