Skip to content

Commit 354b5a3

Browse files
committed
cache num fw args in train and eval
Signed-off-by: Maanu Grover <[email protected]>
1 parent fb13862 commit 354b5a3

File tree

4 files changed

+28
-18
lines changed

4 files changed

+28
-18
lines changed

nemo/tron/api.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Callable
1716

1817
from nemo.tron.checkpointing import save_checkpoint
@@ -40,15 +39,6 @@ def megatron_pretrain(
4039
test_data_iterator = setup_output.test_data_iterator
4140
ckpt_context = setup_output.checkpointing_context
4241

43-
# Check num args to forward_step_func
44-
num_fw_args = len(inspect.signature(forward_step_func).parameters)
45-
fail_msg = f"""
46-
forward_step_func has {num_fw_args} arguments. Only the following signatures are supported:
47-
2 args: forward_step_func(data_iterator: Iterable, model: GPTModel)
48-
3 args: forward_step_func(state: GlobalState, data_iterator: Iterable, model: GPTModel)
49-
"""
50-
assert num_fw_args in (2, 3), fail_msg
51-
5242
## TRAINING ##
5343
if not config.train_config.skip_train:
5444
print_rank_0("training ...")

nemo/tron/eval.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from nemo.tron import fault_tolerance
2626
from nemo.tron.state import GlobalState
2727
from nemo.tron.utils.common_utils import is_last_rank, print_rank_0, print_rank_last
28-
from nemo.tron.utils.train_utils import maybe_inject_state
28+
from nemo.tron.utils.train_utils import check_forward_step_func_num_args, maybe_inject_state
2929

3030

3131
def evaluate(
@@ -39,8 +39,10 @@ def evaluate(
3939
non_loss_data_func=None,
4040
):
4141
"""Evaluation."""
42-
timers = state.timers
42+
# Check num args to forward_step_func
43+
num_fw_args = check_forward_step_func_num_args(forward_step_func)
4344

45+
timers = state.timers
4446
timers("evaluate", log_level=0).start(barrier=True)
4547

4648
# Turn on evaluation mode which disables dropout.
@@ -67,7 +69,7 @@ def evaluate(
6769
if verbose:
6870
print_rank_0(f"Evaluating iter {iteration}/{state.cfg.train_config.eval_iters}")
6971

70-
wrapped_forward_step = maybe_inject_state(forward_step_func, state)
72+
wrapped_forward_step = maybe_inject_state(forward_step_func, state, num_fw_args=num_fw_args)
7173
forward_backward_func = get_forward_backward_func()
7274
# Don't care about timing during evaluation
7375
config.timers = None

nemo/tron/train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from nemo.tron.utils.common_utils import append_to_progress_log, barrier_and_log, get_world_size_safe, print_rank_0
4444
from nemo.tron.utils.train_utils import (
4545
calc_params_l2_norm,
46+
check_forward_step_func_num_args,
4647
logical_and_across_model_parallel_group,
4748
maybe_inject_state,
4849
reduce_max_stat_across_model_parallel_group,
@@ -68,6 +69,9 @@ def train(
6869
timers = global_state.timers
6970
straggler_timer = global_state.straggler_timer
7071

72+
# Check num args to forward_step_func
73+
num_fw_args = check_forward_step_func_num_args(forward_step_func)
74+
7175
# Turn on training mode which enables dropout.
7276
for model_module in model:
7377
model_module.train()
@@ -231,7 +235,7 @@ def train(
231235
# Run training step.
232236
fault_tolerance.on_training_step_start(global_state)
233237
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = train_step(
234-
forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state
238+
forward_step_func, num_fw_args, train_data_iterator, model, optimizer, scheduler, global_state
235239
)
236240
fault_tolerance.on_training_step_end(global_state)
237241
if should_checkpoint:
@@ -403,6 +407,7 @@ def train(
403407

404408
def train_step(
405409
forward_step_func,
410+
num_fw_args,
406411
data_iterator,
407412
model,
408413
optimizer,
@@ -424,7 +429,7 @@ def train_step(
424429
optimizer.zero_grad()
425430

426431
# Optionally inject state into forward step
427-
wrapped_forward_step = maybe_inject_state(forward_step_func, global_state)
432+
wrapped_forward_step = maybe_inject_state(forward_step_func, global_state, num_fw_args=num_fw_args)
428433

429434
# Forward pass.
430435
forward_backward_func = get_forward_backward_func()

nemo/tron/utils/train_utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import inspect
1616
from datetime import datetime
1717
from functools import partial
18-
from typing import Callable
18+
from typing import Callable, Optional
1919

2020
import torch
2121
from megatron.core import parallel_state
@@ -514,10 +514,23 @@ def reduce_aux_losses_tracker_across_ranks():
514514
torch.distributed.all_reduce(values, group=tracker[name]["avg_group"], op=torch.distributed.ReduceOp.AVG)
515515

516516

517-
def maybe_inject_state(forward_step_func: Callable, state: GlobalState) -> Callable:
518-
num_fw_args = len(inspect.signature(forward_step_func).parameters)
517+
def maybe_inject_state(forward_step_func: Callable, state: GlobalState, num_fw_args: Optional[int] = None) -> Callable:
518+
if not num_fw_args:
519+
num_fw_args = len(inspect.signature(forward_step_func).parameters)
519520
if num_fw_args == 3:
520521
# inject global_state
521522
return partial(forward_step_func, state)
522523
else:
523524
return forward_step_func
525+
526+
527+
def check_forward_step_func_num_args(forward_step_func: Callable) -> int:
528+
num_fw_args = len(inspect.signature(forward_step_func).parameters)
529+
fail_msg = f"""
530+
forward_step_func has {num_fw_args} arguments. Only the following signatures are supported:
531+
2 args: forward_step_func(data_iterator: Iterable, model: GPTModel)
532+
3 args: forward_step_func(state: GlobalState, data_iterator: Iterable, model: GPTModel)
533+
"""
534+
assert num_fw_args in (2, 3), fail_msg
535+
536+
return num_fw_args

0 commit comments

Comments
 (0)