11# SPDX-License-Identifier: Apache-2.0
22from dataclasses import dataclass , field
3- from typing import List , Optional , Tuple
3+ from typing import Optional
44
55from fastvideo .v1 .configs .models .encoders .base import (ImageEncoderArchConfig ,
66 ImageEncoderConfig ,
77 TextEncoderArchConfig ,
88 TextEncoderConfig )
99
1010
11- def _is_transformer_layer (n : str , m ) -> bool :
12- return "layers" in n and str .isdigit (n .split ("." )[- 1 ])
13-
14-
15- def _is_embeddings (n : str , m ) -> bool :
16- return n .endswith ("embeddings" )
17-
18-
1911@dataclass
2012class CLIPTextArchConfig (TextEncoderArchConfig ):
2113 vocab_size : int = 49408
@@ -35,15 +27,6 @@ class CLIPTextArchConfig(TextEncoderArchConfig):
3527 bos_token_id : int = 49406
3628 eos_token_id : int = 49407
3729 text_len : int = 77
38- stacked_params_mapping : List [Tuple [str , str ,
39- str ]] = field (default_factory = lambda : [
40- # (param_name, shard_name, shard_id)
41- ("qkv_proj" , "q_proj" , "q" ),
42- ("qkv_proj" , "k_proj" , "k" ),
43- ("qkv_proj" , "v_proj" , "v" ),
44- ])
45- _fsdp_shard_conditions : list = field (
46- default_factory = lambda : [_is_transformer_layer , _is_embeddings ])
4730
4831
4932@dataclass
@@ -62,13 +45,6 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
6245 attention_dropout : float = 0.0
6346 initializer_range : float = 0.02
6447 initializer_factor : float = 1.0
65- stacked_params_mapping : List [Tuple [str , str ,
66- str ]] = field (default_factory = lambda : [
67- # (param_name, shard_name, shard_id)
68- ("qkv_proj" , "q_proj" , "q" ),
69- ("qkv_proj" , "k_proj" , "k" ),
70- ("qkv_proj" , "v_proj" , "v" ),
71- ])
7248
7349
7450@dataclass
0 commit comments