@@ -114,6 +114,19 @@ class VisualGLMPretrainedModel(MixPretrainedModel):

     def _init_weights(self, module):
         """Initialize the weights"""
+
+        def trunc_normal_(tensor, mean=0.0, std=1.0, min=-2, max=2):
+            origin_dtype = paddle.get_default_dtype()
+            paddle.set_default_dtype("float32")
+            with paddle.no_grad():
+                normal = paddle.normal(mean=mean, std=std, shape=tensor.shape)
+                trunc = paddle.clip(normal, min=min, max=max)
+                if origin_dtype != "float32":
+                    trunc = trunc.astype(origin_dtype)
+                tensor.set_value(trunc)
+            paddle.set_default_dtype(origin_dtype)
+            return tensor
+
         factor = self.config.initializer_range
         if isinstance(module, nn.Conv2D) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
             normal_(module.weight, mean=0.0, std=factor)
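
For reference, a minimal sketch (assuming PaddlePaddle 2.x; `std=0.02` is an illustrative stand-in for `initializer_range`, not a value from this diff) of what the new closure does for a single parameter. Clipping a normal sample is not a true truncated normal (it piles boundary mass at ±2 instead of resampling), but with the small std values used here the bounds sit far out in the tail and rarely bind:

```python
# A minimal sketch, assuming PaddlePaddle 2.x; std=0.02 stands in for a
# typical initializer_range and is not a value taken from this diff.
import paddle

weight = paddle.nn.Linear(8, 8).weight  # stand-in for a model parameter
with paddle.no_grad():
    # Sample a normal of the parameter's shape, then clip to [-2, 2],
    # mirroring the trunc_normal_ closure added above.
    sample = paddle.normal(mean=0.0, std=0.02, shape=weight.shape)
    weight.set_value(paddle.clip(sample, min=-2.0, max=2.0))
assert float(weight.abs().max()) <= 2.0
```

The float32 round-trip in the closure matters for low-precision parameters: `paddle.normal` samples at the default dtype, so the helper pins that to float32 and casts the result back before calling `set_value` on a float16/bfloat16 tensor.
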
@@ -123,11 +136,12 @@ def _init_weights(self, module):
         if isinstance(module, VisualGLMVisionEmbeddings):
             if hasattr(self.config, "vision_config"):
                 factor = self.config.vision_config.initializer_range
-            trunc_normal_ = nn.initializer.TruncatedNormal(mean=0.0, std=factor)
+
             trunc_normal_(module.position_embedding)
             trunc_normal_(
                 module.class_embedding,
             )
+
         elif isinstance(module, nn.LayerNorm):
             zeros_(module.bias)
             ones_(module.weight)
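
With the locally constructed `nn.initializer.TruncatedNormal` gone, these two calls now resolve to the closure defined at the top of `_init_weights`. One behavioral difference worth flagging: the old initializer used `std=factor`, while the closure defaults to `std=1.0` and neither call site passes `std`, so `position_embedding` and `class_embedding` are no longer scaled by `initializer_range` unless called as, e.g., `trunc_normal_(module.position_embedding, std=factor)`.
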
@@ -588,7 +602,7 @@ def forward(

         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + attention_mask
+            attention_scores = paddle.cast((attention_scores + attention_mask), attention_scores.dtype)

         # Normalize the attention scores to probabilities.
         attention_probs = nn.Softmax(axis=-1)(attention_scores)
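
A minimal sketch of the step this guard hardens, assuming PaddlePaddle 2.x (shapes and the -1e4 fill value are illustrative). The mask is additive: 0 keeps a key position and a large negative value removes it. Under mixed precision the mask can be float32 while the scores are float16; the addition may then come back promoted, and the new `paddle.cast` pins the softmax input to the scores' dtype (a no-op when the dtypes already match, as in this float32 sketch):

```python
# A minimal sketch, assuming PaddlePaddle 2.x. Shapes and the -1e4 fill
# value are illustrative, not taken from the model code.
import paddle

scores = paddle.randn([1, 2, 4, 4])  # [batch, heads, query_len, key_len]
# Additive mask: 0.0 keeps a key position, -1e4 effectively removes it.
mask = paddle.to_tensor([0.0, 0.0, -1e4, -1e4]).reshape([1, 1, 1, 4])

# The cast is the guard added in this hunk: it pins the result to the
# scores' dtype in case the addition promoted it (e.g. fp16 + fp32).
masked = paddle.cast(scores + mask, scores.dtype)
probs = paddle.nn.Softmax(axis=-1)(masked)

print(probs[0, 0, 0])  # last two key positions get ~0 probability
```
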