
Commit 4212497

fix visualglm predict (#1214)
1 parent c5d00f9 commit 4212497

File tree

1 file changed: +16 -2 lines


paddlemix/models/visualglm/modeling.py

Lines changed: 16 additions & 2 deletions
@@ -114,6 +114,19 @@ class VisualGLMPretrainedModel(MixPretrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
+
+        def trunc_normal_(tensor, mean=0.0, std=1.0, min=-2, max=2):
+            origin_dtype = paddle.get_default_dtype()
+            paddle.set_default_dtype("float32")
+            with paddle.no_grad():
+                normal = paddle.normal(mean=mean, std=std, shape=tensor.shape)
+                trunc = paddle.clip(normal, min=min, max=max)
+                if origin_dtype != "float32":
+                    trunc = trunc.astype(origin_dtype)
+                tensor.set_value(trunc)
+            paddle.set_default_dtype(origin_dtype)
+            return tensor
+
         factor = self.config.initializer_range
         if isinstance(module, nn.Conv2D) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
             normal_(module.weight, mean=0.0, std=factor)
@@ -123,11 +136,12 @@ def _init_weights(self, module):
         if isinstance(module, VisualGLMVisionEmbeddings):
             if hasattr(self.config, "vision_config"):
                 factor = self.config.vision_config.initializer_range
-            trunc_normal_ = nn.initializer.TruncatedNormal(mean=0.0, std=factor)
+
             trunc_normal_(module.position_embedding)
             trunc_normal_(
                 module.class_embedding,
             )
+
         elif isinstance(module, nn.LayerNorm):
             zeros_(module.bias)
             ones_(module.weight)
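
The nested helper replaces the nn.initializer.TruncatedNormal object that this hunk removes: the sampling now always runs in float32 and the result is cast back to the model's default dtype. The commit message only says "fix visualglm predict", so the exact failure is inferred from the change; a plausible reading is that drawing the truncated normal directly under a float16 default dtype breaks half-precision prediction. A minimal standalone sketch of the helper's behavior (the parameter shape and std are illustrative; the helper body mirrors the diff above):

    import paddle

    def trunc_normal_(tensor, mean=0.0, std=1.0, min=-2, max=2):
        origin_dtype = paddle.get_default_dtype()
        paddle.set_default_dtype("float32")        # sample in full precision
        with paddle.no_grad():
            normal = paddle.normal(mean=mean, std=std, shape=tensor.shape)
            trunc = paddle.clip(normal, min=min, max=max)   # bound the tails
            if origin_dtype != "float32":
                trunc = trunc.astype(origin_dtype)  # cast back, e.g. to fp16
            tensor.set_value(trunc)
        paddle.set_default_dtype(origin_dtype)      # restore the default dtype
        return tensor

    paddle.set_default_dtype("float16")             # e.g. a half-precision run
    weight = paddle.create_parameter(shape=[4, 8], dtype="float16")
    trunc_normal_(weight, std=0.02)
    print(weight.dtype)                             # paddle.float16

Note that paddle.clip folds the tails onto the bounds rather than resampling, so strictly this is a clipped normal rather than a true truncated normal; for a small std against bounds of ±2 the difference is negligible.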
@@ -588,7 +602,7 @@ def forward(
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
+            attention_scores = paddle.cast((attention_scores + attention_mask), attention_scores.dtype)
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(axis=-1)(attention_scores)
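
The second change pins the dtype of the masked attention logits: if the precomputed additive attention_mask arrives in a wider dtype than the float16 attention_scores, the sum can come back widened and upset the half-precision softmax/matmul chain that follows, so the result is cast back to attention_scores.dtype. The commit itself doesn't spell out the trigger, so this is inferred from the cast. A sketch of the pattern with toy shapes (names are illustrative stand-ins, not the model's real tensors; assumes an environment where float16 elementwise ops are available):

    import paddle

    # Toy stand-ins for the Q.K^T logits and the precomputed additive mask.
    attention_scores = paddle.rand([1, 2, 4, 4]).astype("float16")
    attention_mask = paddle.zeros([1, 1, 4, 4], dtype="float16")

    # The fix: whatever dtype the addition produces, cast the result back to
    # the dtype of the scores before the softmax that follows in the model.
    attention_scores = paddle.cast(attention_scores + attention_mask,
                                   attention_scores.dtype)
    assert attention_scores.dtype == paddle.float16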
