@@ -126,11 +126,9 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3Tran
 
     def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input_embdings = input_embdings.to(torch.bfloat16)
-        # if self.layer_num_ == 0: print('0: layer_input_before_norm', input_embdings)
         input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to(
             torch.bfloat16
         )
-        # if self.layer_num_ == 0: print('0: layer_input_after_norm', input1)
         cache_kv = self._pre_cache_kv(infer_state, layer_weight)
         q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
         input1 = None
@@ -141,20 +139,15 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:o_after_norm", o)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None
 
-        # if self.layer_num_ == 0: print("0:ffn_hidden_before_norm", input_embdings)
         input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:ffn_hidden_after_norm", input1)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        # if self.layer_num_ == 0: print("0:ffn_out", ffn_out)
         ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:ffn_out_after_norm", ffn_out)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings
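For context, the calls that survive this cleanup trace Gemma 3's "sandwich" normalization: each sub-block is wrapped in a pre-norm and a post-norm, with the residual add applied outside both, and every norm is computed in float32 before casting back to bfloat16. Below is a minimal sketch of that pattern, assuming a standard Gemma-style RMSNorm with (1 + weight) scaling; the class and attribute names (SandwichBlock, attn, ffn) are illustrative, not lightllm's API.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Illustrative RMSNorm; Gemma initializes the scale at zero and applies (1 + w).
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32 for numerical stability, as the diff does;
        # the caller casts the result back to the activation dtype.
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * (1.0 + self.weight.float())

class SandwichBlock(nn.Module):
    # Hypothetical single-process stand-in for the tensor-parallel layer above.
    def __init__(self, dim: int, attn: nn.Module, ffn: nn.Module):
        super().__init__()
        self.input_layernorm = RMSNorm(dim)           # _att_norm in the diff
        self.post_attention_layernorm = RMSNorm(dim)  # _ffn_norm in the diff
        self.pre_feedforward_layernorm = RMSNorm(dim)
        self.post_feedforward_layernorm = RMSNorm(dim)
        self.attn, self.ffn = attn, ffn

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Attention sub-block: pre-norm -> attn -> post-norm -> residual add.
        o = self.attn(self.input_layernorm(h).to(h.dtype))
        h = h + self.post_attention_layernorm(o).to(h.dtype)
        # FFN sub-block: pre-norm -> ffn -> post-norm -> residual add.
        f = self.ffn(self.pre_feedforward_layernorm(h).to(h.dtype))
        h = h + self.post_feedforward_layernorm(f).to(h.dtype)
        return h

The extra post-norms (relative to a plain pre-norm transformer) keep the magnitude of each sub-block's output in check before it enters the residual stream, which is why the diff normalizes both o and ffn_out before the in-place add_.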