
Commit ccb6cc1

Update transformer_layer_infer.py
1 parent 24176b3 commit ccb6cc1

File tree

1 file changed: +0 −7 lines changed


lightllm/models/gemma3/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 7 deletions
@@ -126,11 +126,9 @@ def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma3Tran

     def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input_embdings = input_embdings.to(torch.bfloat16)
-        # if self.layer_num_ == 0: print('0: layer_input_before_norm', input_embdings)
         input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to(
             torch.bfloat16
         )
-        # if self.layer_num_ == 0: print('0: layer_input_after_norm', input1)
         cache_kv = self._pre_cache_kv(infer_state, layer_weight)
         q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
         input1 = None
@@ -141,20 +139,15 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:o_after_norm", o)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

-        # if self.layer_num_ == 0: print("0:ffn_hidden_before_norm", input_embdings)
         input1 = self._pre_feedforward_layernorm(input_embdings.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:ffn_hidden_after_norm", input1)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
         input1 = None
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
-        # if self.layer_num_ == 0: print("0:ffn_out", ffn_out)
         ffn_out = self._post_feedforward_layernorm(ffn_out.float(), infer_state, layer_weight).to(torch.bfloat16)
-        # if self.layer_num_ == 0: print("0:ffn_out_after_norm", ffn_out)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings

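The commit only removes commented-out debug prints; the control flow of context_forward is unchanged. For reference, below is a minimal, self-contained sketch of that flow: sandwich norms around the attention and feed-forward blocks, norms computed in float32 and cast back to bfloat16, and a residual add after each block. The function name and callables are placeholders for illustration, not lightllm APIs, and the tensor-parallel all_reduce steps are omitted.

```python
import torch


def gemma3_style_layer_forward(hidden, att_norm, attention, ffn_norm,
                               pre_ffn_norm, ffn, post_ffn_norm):
    """Hypothetical sketch of the context_forward control flow in the diff above.

    `hidden` is (num_tokens, embed_dim) in bfloat16; the norm/attention/ffn
    arguments are placeholder callables, not lightllm APIs. Norms run in
    float32 and cast back to bfloat16, mirroring the diff; the tensor-parallel
    all_reduce calls are left out.
    """
    # Attention block: input norm -> attention -> post-attention norm -> residual add.
    x = att_norm(hidden.float()).to(torch.bfloat16)
    o = attention(x)
    o = ffn_norm(o.float()).to(torch.bfloat16)  # _ffn_norm in the diff (post-attention)
    hidden = hidden + o

    # Feed-forward block: pre-FFN norm -> FFN -> post-FFN norm -> residual add.
    x = pre_ffn_norm(hidden.float()).to(torch.bfloat16)
    f = ffn(x)
    f = post_ffn_norm(f.float()).to(torch.bfloat16)
    return hidden + f


# Example usage with identity placeholders (checks shapes/dtypes only):
ident = lambda t: t
h = torch.randn(4, 16).to(torch.bfloat16)
out = gemma3_style_layer_forward(h, ident, ident, ident, ident, ident, ident)
```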
