11# SPDX-License-Identifier: Apache-2.0
22from dataclasses import dataclass , field
3- from typing import Optional
3+ from typing import List , Optional , Tuple
44
55from fastvideo .v1 .configs .models .encoders .base import (ImageEncoderArchConfig ,
66 ImageEncoderConfig ,
77 TextEncoderArchConfig ,
88 TextEncoderConfig )
99
1010
11+ def _is_transformer_layer (n : str , m ) -> bool :
12+ return "layers" in n and str .isdigit (n .split ("." )[- 1 ])
13+
14+
15+ def _is_embeddings (n : str , m ) -> bool :
16+ return n .endswith ("embeddings" )
17+
18+
1119@dataclass
1220class CLIPTextArchConfig (TextEncoderArchConfig ):
1321 vocab_size : int = 49408
@@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
2735 bos_token_id : int = 49406
2836 eos_token_id : int = 49407
2937 text_len : int = 77
38+ stacked_params_mapping : List [Tuple [str , str ,
39+ str ]] = field (default_factory = lambda : [
40+ # (param_name, shard_name, shard_id)
41+ ("qkv_proj" , "q_proj" , "q" ),
42+ ("qkv_proj" , "k_proj" , "k" ),
43+ ("qkv_proj" , "v_proj" , "v" ),
44+ ])
45+ _fsdp_shard_conditions : list = field (
46+ default_factory = lambda : [_is_transformer_layer , _is_embeddings ])
3047
3148
3249@dataclass
@@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
4562 attention_dropout : float = 0.0
4663 initializer_range : float = 0.02
4764 initializer_factor : float = 1.0
65+ stacked_params_mapping : List [Tuple [str , str ,
66+ str ]] = field (default_factory = lambda : [
67+ # (param_name, shard_name, shard_id)
68+ ("qkv_proj" , "q_proj" , "q" ),
69+ ("qkv_proj" , "k_proj" , "k" ),
70+ ("qkv_proj" , "v_proj" , "v" ),
71+ ])
4872
4973
5074@dataclass
0 commit comments