Skip to content

Commit 9e3d2cc

Browse files
author
fzilan
committed
tmp: revert clip
1 parent 66e351e commit 9e3d2cc

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mindone/transformers/models/clip/modeling_clip.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ def __init__(self, config):
236236
self.scale = self.head_dim**-0.5
237237
self.dropout = config.attention_dropout
238238

239-
self.k_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
240-
self.v_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
241-
self.q_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
242-
self.out_proj = mint.nn.Linear(self.embed_dim, self.embed_dim)
239+
self.k_proj = nn.Dense(self.embed_dim, self.embed_dim)
240+
self.v_proj = nn.Dense(self.embed_dim, self.embed_dim)
241+
self.q_proj = nn.Dense(self.embed_dim, self.embed_dim)
242+
self.out_proj = nn.Dense(self.embed_dim, self.embed_dim)
243243

244244
def _shape(self, tensor: ms.Tensor, seq_len: int, bsz: int):
245245
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).swapaxes(1, 2)
@@ -328,8 +328,8 @@ def __init__(self, config):
328328
super().__init__()
329329
self.config = config
330330
self.activation_fn = ACT2FN[config.hidden_act]
331-
self.fc1 = mint.nn.Linear(config.hidden_size, config.intermediate_size)
332-
self.fc2 = mint.nn.Linear(config.intermediate_size, config.hidden_size)
331+
self.fc1 = nn.Dense(config.hidden_size, config.intermediate_size)
332+
self.fc2 = nn.Dense(config.intermediate_size, config.hidden_size)
333333

334334
def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
335335
hidden_states = self.fc1(hidden_states)
@@ -777,8 +777,8 @@ def __init__(self, config: CLIPConfig):
777777
self.text_model = CLIPTextTransformer(text_config)
778778
self.vision_model = CLIPVisionTransformer(vision_config)
779779

780-
self.visual_projection = mint.nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
781-
self.text_projection = mint.nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
780+
self.visual_projection = nn.Dense(self.vision_embed_dim, self.projection_dim, has_bias=False)
781+
self.text_projection = nn.Dense(self.text_embed_dim, self.projection_dim, has_bias=False)
782782
self.logit_scale = ms.Parameter(ms.Tensor(self.logit_scale_init_value), name="logit_scale")
783783

784784
# Initialize weights and apply final processing
@@ -977,7 +977,7 @@ def __init__(self, config: CLIPTextConfig):
977977

978978
self.text_model = CLIPTextTransformer(config)
979979

980-
self.text_projection = mint.nn.Linear(config.hidden_size, config.projection_dim, bias=False)
980+
self.text_projection = nn.Dense(config.hidden_size, config.projection_dim, has_bias=False)
981981

982982
# Initialize weights and apply final processing
983983
self.post_init()
@@ -1051,7 +1051,7 @@ def __init__(self, config: CLIPVisionConfig):
10511051

10521052
self.vision_model = CLIPVisionTransformer(config)
10531053

1054-
self.visual_projection = mint.nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1054+
self.visual_projection = nn.Dense(config.hidden_size, config.projection_dim, has_bias=False)
10551055

10561056
# Initialize weights and apply final processing
10571057
self.post_init()

0 commit comments

Comments
 (0)