[Enhance] concat micro batch inputs before emb's and lm_head's forward #1409
Changes from 1 commit
The diff rewrites `_micro_batch_forward` so that the embedding, final norm, and `lm_head` forwards each run once over the concatenated micro-batches instead of once per micro-batch:

```diff
@@ -326,24 +326,19 @@ def _micro_batch_forward(
         assert len(seq_ctx_list) == len(loss_ctx_list), "seq_ctx and loss_ctx must have same length"
 
-        # Prepare input embeddings for all micro-batches
-        hidden_states_list: list[torch.Tensor] = []
-        position_embeddings_list = []
-
-        for ctx in seq_ctx_list:
-            input_ids = ctx.input_ids
-            position_ids = ctx.position_ids
-
-            if input_ids is not None:
-                hidden_states = self.embed_tokens(input_ids)
-            else:
-                hidden_states = ctx.inputs_embeds
-
-            # create position embeddings to be shared across the decoder layers
-            assert position_ids is not None
-            position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-            hidden_states_list.append(hidden_states)
-            position_embeddings_list.append(position_embeddings)
+        if seq_ctx_list[0].input_ids is None:
+            cat_hidden_states = torch.cat([ctx.inputs_embeds for ctx in seq_ctx_list], dim=1)
+        else:
+            cat_input_ids = torch.cat([ctx.input_ids for ctx in seq_ctx_list], dim=1)
+            cat_hidden_states = self.embed_tokens(cat_input_ids)
+        cat_position_ids = torch.cat([ctx.position_ids for ctx in seq_ctx_list], dim=1)
+        cat_position_embeddings = self.rotary_emb(cat_hidden_states, cat_position_ids)
+        position_embeddings_list = list(
+            zip(
+                cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=1),
+                cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=1),
+            )
+        )
 
         # Initialize output containers
         output: dict = {}
```
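The first hunk replaces the per-micro-batch embedding loop with a single `embed_tokens` and `rotary_emb` forward over the concatenated tokens, then chunks the rotary embeddings back per micro-batch. A minimal, self-contained sketch of that concat-then-chunk pattern; the toy `embed_tokens`/`rotary_emb` below and the equal micro-batch lengths are illustrative assumptions, not the repository's actual modules:

```python
import torch

# Toy stand-ins for the model's token-embedding and rotary-embedding modules.
embed_tokens = torch.nn.Embedding(1000, 64)

def rotary_emb(hidden_states: torch.Tensor, position_ids: torch.Tensor):
    # Returns (cos, sin); hidden_states is unused here but kept to mirror the call signature.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, 32).float() / 32))
    freqs = position_ids[..., None].float() * inv_freq          # (1, seq, 32)
    return freqs.cos(), freqs.sin()

# Two micro-batches of equal packed length (chunk() assumes equal splits).
micro_input_ids = [torch.randint(0, 1000, (1, 8)) for _ in range(2)]
micro_position_ids = [torch.arange(8)[None], torch.arange(8)[None]]

# One embedding / rotary forward over the concatenation instead of one per micro-batch.
cat_input_ids = torch.cat(micro_input_ids, dim=1)                # (1, 16)
cat_hidden_states = embed_tokens(cat_input_ids)                  # (1, 16, 64)
cat_position_ids = torch.cat(micro_position_ids, dim=1)          # (1, 16)
cos, sin = rotary_emb(cat_hidden_states, cat_position_ids)

# Chunk back so downstream per-micro-batch code still receives a list of (cos, sin) pairs.
position_embeddings_list = list(zip(cos.chunk(2, dim=1), sin.chunk(2, dim=1)))
assert position_embeddings_list[0][0].shape == (1, 8, 32)
```

The payoff is one large `embed_tokens`/`rotary_emb` launch instead of one per micro-batch, while `chunk` recovers the per-micro-batch views cheaply.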
```diff
@@ -353,28 +348,22 @@ def _micro_batch_forward(
 
         # Process through layers
         cat_seq_ctx: SequenceContext | None = None
-        cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
-        cat_hidden_states: torch.Tensor | None = None
-
         moe_forawrd = False
         for idx, decoder_layer in self.layers.items():
             layer_idx = int(idx)
 
             if layer_idx < self.config.first_k_dense_replace:
                 if cat_seq_ctx is None:
-                    cat_seq_ctx = SequenceContext.pack(seq_ctx_list)
-                    cos = torch.cat([pe[0] for pe in position_embeddings_list], dim=1)
-                    sin = torch.cat([pe[1] for pe in position_embeddings_list], dim=1)
-                    cat_position_embeddings = (cos, sin)
-                    cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+                    cat_seq_ctx = SequenceContext.cat(seq_ctx_list)
                 # Dense decoder layer - process concated hidden states
                 cat_hidden_states = decoder_layer(
                     cat_hidden_states,
                     position_embeddings=cat_position_embeddings,
                     seq_ctx=cat_seq_ctx,
                 )
             else:
-                if cat_hidden_states is not None and not moe_forawrd:
+                if not moe_forawrd:
                     # TODO: `i.clone()` here is weird. However, the current Implementation of
                     # `async_save_on_cpu` is not friendly with `chunk` op (maybe caused by shared storage? not sure),
                     # resulting in nan grad norm. So we have to clone the chunked tensors here to make sure each
```
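The TODO kept above explains why the MoE branch clones the chunked tensors. A small sketch of the storage-sharing behaviour it works around; the interaction with `async_save_on_cpu` is the PR's own observation, and the snippet only demonstrates that `chunk` returns views:

```python
import torch

# `torch.chunk` returns views that share storage with the concatenated tensor.
cat_hidden_states = torch.randn(1, 12, 8)
chunks = cat_hidden_states.chunk(3, dim=1)
assert chunks[0].untyped_storage().data_ptr() == cat_hidden_states.untyped_storage().data_ptr()

# Cloning materialises each micro-batch into its own storage, which is the workaround
# the TODO describes for activation-offloading hooks that appear to track tensors by storage.
cloned = [c.clone() for c in chunks]
assert cloned[0].untyped_storage().data_ptr() != cat_hidden_states.untyped_storage().data_ptr()
```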
```diff
@@ -415,25 +404,24 @@ def _micro_batch_forward(
                     router_weights_list[i][f"layer{idx}"] = router_weights[i]
 
-        # Apply final norm to all micro-batches
-        for i, hidden_states in enumerate(hidden_states_list):
-            hidden_states_list[i] = self.norm(hidden_states)
+        cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+        cat_hidden_states = self.norm(cat_hidden_states)
 
-        # Process final outputs for each micro-batch
-        loss_list: list[torch.Tensor] = []
-        logits_list: list[torch.Tensor] = []
-        moe_extra_info = ModelForwardExtraLogInfo()
-        for hidden_states, loss_ctx_single in zip(hidden_states_list, loss_ctx_list):
-            loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx_single)  # type: ignore
-            loss_list.append(loss)
-            if logits is not None:
-                logits_list.append(logits)
-            if extra_info:
-                moe_extra_info.append(extra_info)
+        cat_loss_ctx = CELossContext.cat(loss_ctx_list)
+        loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx)  # type: ignore
 
         # Aggregate losses (mean across micro-batches)
```
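This hunk applies `self.norm` and `self.lm_head` once over the concatenated hidden states instead of once per micro-batch. As a sanity check on why a single head call can stand in for the per-micro-batch loop, here is a sketch with a plain `nn.Linear` plus `F.cross_entropy` standing in for the repository's `lm_head`/`CELossContext`, and a sum reduction as the assumed aggregation; the actual loss context may weight tokens differently:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
head = torch.nn.Linear(16, 100)

# Two micro-batches of hidden states and labels (concatenation happens along dim=1).
hidden = [torch.randn(1, 6, 16), torch.randn(1, 6, 16)]
labels = [torch.randint(0, 100, (1, 6)), torch.randint(0, 100, (1, 6))]

# Per-micro-batch losses, summed afterwards.
per_micro_batch = sum(
    F.cross_entropy(head(h).flatten(0, 1), y.flatten(), reduction="sum")
    for h, y in zip(hidden, labels)
)

# One head forward over the concatenated hidden states and labels.
cat_hidden = torch.cat(hidden, dim=1)
cat_labels = torch.cat(labels, dim=1)
cat_loss = F.cross_entropy(head(cat_hidden).flatten(0, 1), cat_labels.flatten(), reduction="sum")

# Both paths yield the same total loss (up to floating-point rounding).
torch.testing.assert_close(per_micro_batch, cat_loss)
```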
Review comments on this hunk:

An outdated review suggestion proposed rewording the aggregation comment:

```diff
-        # Aggregate losses (mean across micro-batches)
+        # Aggregate loss value (using sum across micro-batches or scalar loss as returned)
```

Copilot (AI), Jan 4, 2026 (outdated): The type annotation `loss: torch.Tensor` on line 415 is redundant, since `loss` is already assigned on line 412. It adds no new information and can be confusing because it sits between the assignment and the use of the variable. Consider removing it:

```diff
-        loss: torch.Tensor
```
Reviewer: Please modify the type annotation of `LMHead.__call__`; the type of `loss` could then be inferred as `torch.Tensor` automatically.
Reply: If `loss_ctx` is not None, the `__call__` method returns a `torch.Tensor`; otherwise, it returns `None`. How should I modify the type hint? @HAOCHENYE
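One way to express that conditional return type is `typing.overload`. A sketch under simplified assumptions: a bare-bones `LMHead` that returns only the loss and a placeholder `CELossContext`, whereas the real head also returns logits and extra info:

```python
from typing import overload

import torch


class CELossContext:
    """Hypothetical stand-in for the repository's loss-context type."""


class LMHead:
    """Simplified head; the real one returns (loss, (logits, extra_info))."""

    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        self.proj = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    # The overloads tell the type checker: a loss context yields a Tensor, no context yields None.
    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: CELossContext) -> torch.Tensor: ...
    @overload
    def __call__(self, hidden_states: torch.Tensor, loss_ctx: None = ...) -> None: ...

    def __call__(
        self, hidden_states: torch.Tensor, loss_ctx: CELossContext | None = None
    ) -> torch.Tensor | None:
        logits = self.proj(hidden_states)
        if loss_ctx is None:
            return None
        # Placeholder loss so the sketch runs; the real computation lives in the repository.
        return logits.logsumexp(dim=-1).mean()
```

With overloads like these, `loss, ... = self.lm_head(cat_hidden_states, cat_loss_ctx)` lets the checker infer `loss` as `torch.Tensor` without the extra annotation Copilot flagged.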
Reviewer (outdated): Remove the variable `final_logits`.
Reply: `final_logits` is used when constructing `MoEModelOutputs`.