@@ -236,10 +236,10 @@ def __init__(self, config):
236236 self .scale = self .head_dim ** - 0.5
237237 self .dropout = config .attention_dropout
238238
239- self .k_proj = mint . nn .Linear (self .embed_dim , self .embed_dim )
240- self .v_proj = mint . nn .Linear (self .embed_dim , self .embed_dim )
241- self .q_proj = mint . nn .Linear (self .embed_dim , self .embed_dim )
242- self .out_proj = mint . nn .Linear (self .embed_dim , self .embed_dim )
239+ self .k_proj = nn .Dense (self .embed_dim , self .embed_dim )
240+ self .v_proj = nn .Dense (self .embed_dim , self .embed_dim )
241+ self .q_proj = nn .Dense (self .embed_dim , self .embed_dim )
242+ self .out_proj = nn .Dense (self .embed_dim , self .embed_dim )
243243
def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int):
    """Rearrange a flat hidden-state tensor into per-attention-head layout.

    Splits the last dimension into (num_heads, head_dim) and moves the
    heads axis in front of the sequence axis, producing a tensor of shape
    (bsz, num_heads, seq_len, head_dim).

    Args:
        tensor: hidden states; assumes shape (bsz, seq_len,
            num_heads * head_dim) — TODO confirm against callers.
        seq_len: sequence length of ``tensor``.
        bsz: batch size of ``tensor``.

    Returns:
        The reshaped tensor with heads as the second axis.
    """
    # (bsz, seq_len, embed_dim) -> (bsz, seq_len, num_heads, head_dim)
    per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
    # swap seq_len and num_heads axes -> (bsz, num_heads, seq_len, head_dim)
    return per_head.swapaxes(1, 2)
@@ -328,8 +328,8 @@ def __init__(self, config):
328328 super ().__init__ ()
329329 self .config = config
330330 self .activation_fn = ACT2FN [config .hidden_act ]
331- self .fc1 = mint . nn .Linear (config .hidden_size , config .intermediate_size )
332- self .fc2 = mint . nn .Linear (config .intermediate_size , config .hidden_size )
331+ self .fc1 = nn .Dense (config .hidden_size , config .intermediate_size )
332+ self .fc2 = nn .Dense (config .intermediate_size , config .hidden_size )
333333
334334 def construct (self , hidden_states : ms .Tensor ) -> ms .Tensor :
335335 hidden_states = self .fc1 (hidden_states )
@@ -777,8 +777,8 @@ def __init__(self, config: CLIPConfig):
777777 self .text_model = CLIPTextTransformer (text_config )
778778 self .vision_model = CLIPVisionTransformer (vision_config )
779779
780- self .visual_projection = mint . nn .Linear (self .vision_embed_dim , self .projection_dim , bias = False )
781- self .text_projection = mint . nn .Linear (self .text_embed_dim , self .projection_dim , bias = False )
780+ self .visual_projection = nn .Dense (self .vision_embed_dim , self .projection_dim , has_bias = False )
781+ self .text_projection = nn .Dense (self .text_embed_dim , self .projection_dim , has_bias = False )
782782 self .logit_scale = ms .Parameter (ms .Tensor (self .logit_scale_init_value ), name = "logit_scale" )
783783
784784 # Initialize weights and apply final processing
@@ -977,7 +977,7 @@ def __init__(self, config: CLIPTextConfig):
977977
978978 self .text_model = CLIPTextTransformer (config )
979979
980- self .text_projection = mint . nn .Linear (config .hidden_size , config .projection_dim , bias = False )
980+ self .text_projection = nn .Dense (config .hidden_size , config .projection_dim , has_bias = False )
981981
982982 # Initialize weights and apply final processing
983983 self .post_init ()
@@ -1051,7 +1051,7 @@ def __init__(self, config: CLIPVisionConfig):
10511051
10521052 self .vision_model = CLIPVisionTransformer (config )
10531053
1054- self .visual_projection = mint . nn .Linear (config .hidden_size , config .projection_dim , bias = False )
1054+ self .visual_projection = nn .Dense (config .hidden_size , config .projection_dim , has_bias = False )
10551055
10561056 # Initialize weights and apply final processing
10571057 self .post_init ()
0 commit comments