[Enhance] concat micro batch inputs before emb's and lm_head's forward #1409
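This PR concatenates the micro-batch inputs along the sequence dimension so that `embed_tokens`, the final `norm`, and `lm_head` each run once per forward pass instead of once per micro-batch. A minimal sketch of why the concatenated embedding lookup matches the per-micro-batch loop it replaces (the embedding layer, vocabulary size, and sequence lengths below are illustrative, not the project's model):

```python
import torch
import torch.nn as nn

# Two hypothetical micro-batches, each packed as (1, seq_len).
embed_tokens = nn.Embedding(num_embeddings=1000, embedding_dim=64)
mb0 = torch.randint(0, 1000, (1, 128))
mb1 = torch.randint(0, 1000, (1, 96))

# Old pattern: one embedding forward per micro-batch.
per_mb = [embed_tokens(mb0), embed_tokens(mb1)]

# New pattern: concatenate along the sequence dim and embed once.
cat_ids = torch.cat([mb0, mb1], dim=1)   # (1, 224)
cat_hidden = embed_tokens(cat_ids)       # (1, 224, 64)

# The lookup is position-wise, so the two paths agree.
assert torch.allclose(cat_hidden, torch.cat(per_mb, dim=1))
```

The same concat-once / split-back pattern is applied to the rotary position embeddings and, at the output end, to the final norm and the `lm_head` call, as the hunks below show.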
@@ -326,24 +326,19 @@ def _micro_batch_forward(
         assert len(seq_ctx_list) == len(loss_ctx_list), "seq_ctx and loss_ctx must have same length"

-        # Prepare input embeddings for all micro-batches
-        hidden_states_list: list[torch.Tensor] = []
-        position_embeddings_list = []
-
-        for ctx in seq_ctx_list:
-            input_ids = ctx.input_ids
-            position_ids = ctx.position_ids
-
-            if input_ids is not None:
-                hidden_states = self.embed_tokens(input_ids)
-            else:
-                hidden_states = ctx.inputs_embeds
-
-            # create position embeddings to be shared across the decoder layers
-            assert position_ids is not None
-            position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-            hidden_states_list.append(hidden_states)
-            position_embeddings_list.append(position_embeddings)
+        if seq_ctx_list[0].input_ids is None:
+            cat_hidden_states = torch.cat([ctx.inputs_embeds for ctx in seq_ctx_list], dim=1)  # type: ignore
+        else:
+            cat_input_ids = torch.cat([ctx.input_ids for ctx in seq_ctx_list], dim=1)  # type: ignore
+            cat_hidden_states = self.embed_tokens(cat_input_ids)
+        cat_position_ids = torch.cat([ctx.position_ids for ctx in seq_ctx_list], dim=1)  # type: ignore
+        cat_position_embeddings = self.rotary_emb(cat_hidden_states, cat_position_ids)  # type: ignore
+        position_embeddings_list = list(
+            zip(
+                cat_position_embeddings[0].chunk(len(seq_ctx_list), dim=1),
+                cat_position_embeddings[1].chunk(len(seq_ctx_list), dim=1),
+            )
+        )

         # Initialize output containers
         output: dict = {}

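In the new code above, `position_embeddings_list` is rebuilt by chunking a single cos/sin pair computed on the concatenated `position_ids`. This works because rotary cos/sin values depend only elementwise on the position index. A toy sketch (the rotary formula, the head dimension, and the equal micro-batch lengths are assumptions for illustration, not the project's `rotary_emb`):

```python
import torch

def toy_rotary(position_ids: torch.Tensor, dim: int = 8):
    # Simplified rotary-style cos/sin; stands in for self.rotary_emb.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = position_ids.float()[..., None] * inv_freq   # (1, seq, dim/2)
    emb = torch.cat([freqs, freqs], dim=-1)               # (1, seq, dim)
    return emb.cos(), emb.sin()

# Two micro-batches of equal packed length, as an even chunk() split presupposes.
pos_list = [torch.arange(128)[None, :], torch.arange(128)[None, :]]

# Compute cos/sin once on the concatenated position_ids, then chunk back per micro-batch.
cos, sin = toy_rotary(torch.cat(pos_list, dim=1))
position_embeddings_list = list(zip(cos.chunk(2, dim=1), sin.chunk(2, dim=1)))

# cos/sin are elementwise in the position axis, so chunking matches per-micro-batch results.
ref_cos, ref_sin = toy_rotary(pos_list[1])
assert torch.allclose(position_embeddings_list[1][0], ref_cos)
assert torch.allclose(position_embeddings_list[1][1], ref_sin)
```

Note that `chunk(len(seq_ctx_list), dim=1)` splits into equal-sized pieces, so the sketch assumes all micro-batches share the same packed sequence length.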
@@ -353,35 +348,29 @@ def _micro_batch_forward(

         # Process through layers
-        cat_seq_ctx: SequenceContext | None = None
-        cat_position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None
-        cat_hidden_states: torch.Tensor | None = None
-
-        moe_forawrd = False
+        moe_forward = False
         for idx, decoder_layer in self.layers.items():
             layer_idx = int(idx)

             if layer_idx < self.config.first_k_dense_replace:
-                if cat_seq_ctx is None:
-                    cat_seq_ctx = SequenceContext.pack(seq_ctx_list)
-                    cos = torch.cat([pe[0] for pe in position_embeddings_list], dim=1)
-                    sin = torch.cat([pe[1] for pe in position_embeddings_list], dim=1)
-                    cat_position_embeddings = (cos, sin)
-                    cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+                cat_seq_ctx = SequenceContext.cat(seq_ctx_list)
                 # Dense decoder layer - process concated hidden states
                 cat_hidden_states = decoder_layer(
                     cat_hidden_states,
                     position_embeddings=cat_position_embeddings,
                     seq_ctx=cat_seq_ctx,
                 )
             else:
-                if cat_hidden_states is not None and not moe_forawrd:
+                if not moe_forward:
                     # TODO: `i.clone()` here is weird. However, the current Implementation of
                     # `async_save_on_cpu` is not friendly with `chunk` op (maybe caused by shared storage? not sure),
                     # resulting in nan grad norm. So we have to clone the chunked tensors here to make sure each
                     # hidden state has its own storage. This workaround may introduce extra memory and time cost, and
                     # should be optimized in the future.
                     hidden_states_list = [i.clone() for i in cat_hidden_states.chunk(len(seq_ctx_list), dim=1)]
-                    moe_forawrd = True
+                    moe_forward = True

                 if int(os.getenv("XTUNER_ACTIVATION_OFFLOAD", "0")) == 1:
                     with async_save_on_cpu(

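The TODO above clones each chunk because `async_save_on_cpu` reportedly produces NaN grad norms when the per-micro-batch hidden states are raw `chunk` views sharing one storage. A small sketch of the aliasing that `clone()` removes (shapes are arbitrary):

```python
import torch

cat_hidden_states = torch.randn(1, 256, 8)
chunks = cat_hidden_states.chunk(2, dim=1)

# chunk() returns views: every chunk aliases the storage of cat_hidden_states.
assert chunks[0].untyped_storage().data_ptr() == cat_hidden_states.untyped_storage().data_ptr()
assert chunks[1].untyped_storage().data_ptr() == cat_hidden_states.untyped_storage().data_ptr()

# clone() gives each piece its own storage, at the cost of an extra copy,
# which is what the `i.clone()` workaround relies on.
owned = [c.clone() for c in chunks]
assert owned[0].untyped_storage().data_ptr() != cat_hidden_states.untyped_storage().data_ptr()
```

The clone trades some extra memory and copy time for independent storages, as the TODO itself notes.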
@@ -415,23 +404,18 @@ def _micro_batch_forward(
                         router_weights_list[i][f"layer{idx}"] = router_weights[i]

-        # Apply final norm to all micro-batches
-        for i, hidden_states in enumerate(hidden_states_list):
-            hidden_states_list[i] = self.norm(hidden_states)
+        cat_hidden_states = torch.cat(hidden_states_list, dim=1)
+        cat_hidden_states = self.norm(cat_hidden_states)

         # Process final outputs for each micro-batch
-        loss_list: list[torch.Tensor] = []
-        logits_list: list[torch.Tensor] = []
         moe_extra_info = ModelForwardExtraLogInfo()
-        for hidden_states, loss_ctx_single in zip(hidden_states_list, loss_ctx_list):
-            loss, (logits, extra_info) = self.lm_head(hidden_states, loss_ctx_single)  # type: ignore
-            loss_list.append(loss)
-            if logits is not None:
-                logits_list.append(logits)
-            if extra_info:
-                moe_extra_info.append(extra_info)
+        cat_loss_ctx = CELossContext.cat(loss_ctx_list)
+        loss, (logits, extra_info) = self.lm_head(cat_hidden_states, cat_loss_ctx)

-        # Aggregate losses (mean across micro-batches)
+        # Aggregate loss value (using sum across micro-batches or scalar loss as returned)
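At the output end, a single `lm_head` call over the concatenated hidden states replaces the per-micro-batch loop, with `CELossContext.cat` merging the per-micro-batch loss contexts. A sketch of why the single call can match the loop when the loss is reduced by token sum (plain `cross_entropy` is used here as a stand-in for the project's loss context; the hidden size and vocabulary below are made up):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden = [torch.randn(1, 128, 64) for _ in range(2)]        # per-micro-batch hidden states
labels = [torch.randint(0, 1000, (1, 128)) for _ in range(2)]
lm_head = nn.Linear(64, 1000, bias=False)

# Old pattern: one lm_head + loss per micro-batch, with token-sum losses.
per_mb_loss = [
    F.cross_entropy(lm_head(h).flatten(0, 1), y.flatten(), reduction="sum")
    for h, y in zip(hidden, labels)
]

# New pattern: concatenate hidden states and labels, call lm_head once.
cat_logits = lm_head(torch.cat(hidden, dim=1))
cat_loss = F.cross_entropy(
    cat_logits.flatten(0, 1), torch.cat(labels, dim=1).flatten(), reduction="sum"
)

# With a token-sum reduction the two paths agree, which is why a single call
# over the concatenation can replace the per-micro-batch loop.
assert torch.allclose(cat_loss, sum(per_mb_loss), rtol=1e-4)
```

A global token-mean works the same way as long as the normalizer is the total token count across micro-batches rather than a per-micro-batch mean, which appears to be what the updated aggregation comment refers to.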