-
Notifications
You must be signed in to change notification settings - Fork 3.4k
LLM Forward Step #12673
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
LLM Forward Step #12673
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
b32914f
pretrain loss func
maanug-nv f3c03dd
get batch and forward
maanug-nv d397d34
add rerun functionality to loss
maanug-nv 7d21d7e
formatting
maanug-nv 27515de
injection of state
maanug-nv 2181140
remove globalstate singleton functionality
maanug-nv 46ef694
update example
maanug-nv 82bf9f6
missing copyright
maanug-nv 75c5fe3
fix for latest mcore
maanug-nv 080901c
syntax
maanug-nv 6f085c9
move assertion
maanug-nv 686d6f9
refactor for eval
maanug-nv b7ac969
move to avoid circular import
maanug-nv 71894a1
fix
maanug-nv fb13862
unused
maanug-nv 354b5a3
cache num fw args in train and eval
maanug-nv b31d7f9
docstring fix
maanug-nv 430741f
remove duplicate
maanug-nv File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from functools import partial | ||
| from typing import Iterable | ||
|
|
||
| from megatron.core import parallel_state | ||
| from megatron.core.models.gpt import GPTModel | ||
| from megatron.core.utils import get_batch_on_this_cp_rank | ||
|
|
||
| from nemo.tron.config import ConfigContainer | ||
| from nemo.tron.llm.utils import get_batch_on_this_tp_rank | ||
| from nemo.tron.losses import masked_next_token_loss | ||
| from nemo.tron.state import GlobalState | ||
|
|
||
|
|
||
| def get_batch(data_iterator, cfg: ConfigContainer): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be good to add typehint + docs for the return value |
||
| """Generate a batch.""" | ||
|
|
||
| if (not parallel_state.is_pipeline_first_stage()) and (not parallel_state.is_pipeline_last_stage()): | ||
| return None, None, None, None, None | ||
|
|
||
| # get batches based on the TP rank you are on | ||
| batch = get_batch_on_this_tp_rank(data_iterator, cfg) | ||
|
|
||
| # slice batch along sequence dimension for context parallelism | ||
| batch = get_batch_on_this_cp_rank(batch) | ||
|
|
||
| return batch.values() | ||
|
|
||
|
|
||
| def forward_step(state: GlobalState, data_iterator: Iterable, model: GPTModel): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here, the return type will be helpful |
||
| """Forward training step. | ||
|
|
||
| Args: | ||
| data_iterator : Input data iterator | ||
| model (GPTModel): The GPT Model | ||
maanug-nv marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
|
|
||
| timers = state.cfg.model_config.timers | ||
maanug-nv marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| straggler_timer = state.straggler_timer | ||
|
|
||
| timers("batch-generator", log_level=2).start() | ||
| with straggler_timer(bdata=True): | ||
| tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, state.cfg) | ||
| timers("batch-generator").stop() | ||
|
|
||
| with straggler_timer: | ||
| output_tensor = model(tokens, position_ids, attention_mask, labels=labels) | ||
|
|
||
| return output_tensor, partial(masked_next_token_loss, loss_mask) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import Dict, Iterable | ||
| import torch | ||
| from megatron.core import parallel_state | ||
| from nemo.tron.config import ConfigContainer | ||
|
|
||
|
|
||
| def get_batch_on_this_tp_rank(data_iterator: Iterable, cfg: ConfigContainer) -> Dict[str, torch.Tensor]: | ||
| def _broadcast(item): | ||
| if item is not None: | ||
| torch.distributed.broadcast( | ||
| item, | ||
| parallel_state.get_tensor_model_parallel_src_rank(), | ||
| group=parallel_state.get_tensor_model_parallel_group(), | ||
| ) | ||
|
|
||
| if parallel_state.get_tensor_model_parallel_rank() == 0: | ||
| if data_iterator is not None: | ||
| data = next(data_iterator) | ||
| else: | ||
| data = None | ||
|
|
||
| batch = { | ||
| "tokens": data["tokens"].cuda(non_blocking=True), | ||
| "labels": data["labels"].cuda(non_blocking=True), | ||
| "loss_mask": data["loss_mask"].cuda(non_blocking=True), | ||
| "attention_mask": None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True), | ||
| "position_ids": data["position_ids"].cuda(non_blocking=True), | ||
| } | ||
|
|
||
| if cfg.model_config.pipeline_model_parallel_size == 1: | ||
| _broadcast(batch["tokens"]) | ||
| _broadcast(batch["labels"]) | ||
| _broadcast(batch["loss_mask"]) | ||
| _broadcast(batch["attention_mask"]) | ||
| _broadcast(batch["position_ids"]) | ||
|
|
||
| elif parallel_state.is_pipeline_first_stage(): | ||
| _broadcast(batch["tokens"]) | ||
| _broadcast(batch["attention_mask"]) | ||
| _broadcast(batch["position_ids"]) | ||
|
|
||
| elif parallel_state.is_pipeline_last_stage(): | ||
| _broadcast(batch["labels"]) | ||
| _broadcast(batch["loss_mask"]) | ||
| _broadcast(batch["attention_mask"]) | ||
|
|
||
| else: | ||
| mbs = cfg.train_config.micro_batch_size | ||
| seq_length = cfg.model_config.seq_length | ||
| tokens = torch.empty( | ||
| (mbs, seq_length), | ||
| dtype=torch.int64, | ||
| device=torch.cuda.current_device(), | ||
| ) | ||
| labels = torch.empty( | ||
| (mbs, seq_length), | ||
| dtype=torch.int64, | ||
| device=torch.cuda.current_device(), | ||
| ) | ||
| loss_mask = torch.empty( | ||
| (mbs, seq_length), | ||
| dtype=torch.float32, | ||
| device=torch.cuda.current_device(), | ||
| ) | ||
| if cfg.dataset_config.create_attention_mask_in_dataloader: | ||
| attention_mask = torch.empty( | ||
| ( | ||
| mbs, | ||
| 1, | ||
| seq_length, | ||
| seq_length, | ||
| ), | ||
| dtype=torch.bool, | ||
| device=torch.cuda.current_device(), | ||
| ) | ||
| else: | ||
| attention_mask = None | ||
| position_ids = torch.empty( | ||
| (mbs, seq_length), | ||
| dtype=torch.int64, | ||
| device=torch.cuda.current_device(), | ||
| ) | ||
|
|
||
| if cfg.model_config.pipeline_model_parallel_size == 1: | ||
| _broadcast(tokens) | ||
| _broadcast(labels) | ||
| _broadcast(loss_mask) | ||
| _broadcast(attention_mask) | ||
| _broadcast(position_ids) | ||
|
|
||
| elif parallel_state.is_pipeline_first_stage(): | ||
| labels = None | ||
| loss_mask = None | ||
|
|
||
| _broadcast(tokens) | ||
| _broadcast(attention_mask) | ||
| _broadcast(position_ids) | ||
|
|
||
| elif parallel_state.is_pipeline_last_stage(): | ||
| tokens = None | ||
| position_ids = None | ||
|
|
||
| _broadcast(labels) | ||
| _broadcast(loss_mask) | ||
| _broadcast(attention_mask) | ||
|
|
||
| batch = { | ||
| "tokens": tokens, | ||
| "labels": labels, | ||
| "loss_mask": loss_mask, | ||
| "attention_mask": attention_mask, | ||
| "position_ids": position_ids, | ||
| } | ||
|
|
||
| return batch |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.