14 changes: 14 additions & 0 deletions megatron/training/arguments.py
@@ -78,6 +78,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser):
    parser = _add_one_logger_args(parser)
    parser = _add_inprocess_restart_args(parser)
    parser = _add_ft_package_args(parser)
    parser = _add_tensor_inspect_args(parser)
    parser = _add_rerun_machine_args(parser)
    parser = _add_msc_args(parser)
    parser = _add_kitchen_quantization_arguments(parser)
@@ -2120,6 +2121,19 @@ def _add_ft_package_args(parser):
    return parser


def _add_tensor_inspect_args(parser):
    group = parser.add_argument_group(title='tensor_inspect')
    group.add_argument('--tensor-inspect', action='store_true',
                       help='Enable tensor inspection via NVIDIA DLFw Inspect.')
    group.add_argument('--tensor-inspect-config', type=str, default=None,
                       help='Path to YAML config for tensor inspection features.')
    group.add_argument('--tensor-inspect-log-dir', type=str, default=None,
                       help='Directory for tensor inspection logs.')
    group.add_argument('--tensor-inspect-feature-dirs', type=str, nargs='+', default=None,
                       help='Directories containing tensor inspection feature implementations.')
    return parser


def _add_logging_args(parser):
    from megatron.training.config import LoggerConfig

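For reference, a minimal sketch of how the new flags parse, using the argument group added above. The paths and directory names are illustrative, not part of this change:

```python
import argparse

from megatron.training.arguments import _add_tensor_inspect_args

parser = argparse.ArgumentParser()
_add_tensor_inspect_args(parser)
args = parser.parse_args([
    '--tensor-inspect',
    '--tensor-inspect-config', 'configs/inspect.yaml',  # illustrative path
    '--tensor-inspect-log-dir', './inspect_logs',       # illustrative path
    '--tensor-inspect-feature-dirs', 'my_features',     # illustrative directory
])
assert args.tensor_inspect is True
assert args.tensor_inspect_feature_dirs == ['my_features']  # nargs='+' yields a list
```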
157 changes: 157 additions & 0 deletions megatron/training/tensor_inspect.py
@@ -0,0 +1,157 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""NVIDIA DLFw Inspect integration for tensor inspection and statistics collection."""

from typing import Any, List, Optional

from megatron.training.utils import print_rank_0


MISSING_NVINSPECT_MSG = (
    "nvdlfw_inspect is not available. Please install it with `pip install nvdlfw-inspect`."
)

try:
    import nvdlfw_inspect.api as nvinspect_api
    from nvdlfw_inspect.logging import BaseLogger, MetricLogger, wrap_tensorboard_writer

    HAVE_NVINSPECT = True
except (ImportError, ModuleNotFoundError):
    HAVE_NVINSPECT = False
    nvinspect_api = None
    BaseLogger = None
    MetricLogger = None

    def wrap_tensorboard_writer(x):
        return x


def _get_default_feature_dirs() -> List[str]:
    """Get default feature directories from installed packages."""
    feature_dirs = []
    try:
        import importlib
        from pathlib import Path

        te_features_mod = importlib.import_module("transformer_engine.debug.features")
        te_features_dir = Path(te_features_mod.__file__).parent
        if te_features_dir.exists():
            feature_dirs.append(str(te_features_dir))
    except Exception:
        pass

    return feature_dirs


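# Illustrative example: with the prefixes below, a metric reported for
# "model.module.module.decoder.layers.0.self_attention" is logged under
# "decoder.layers.0.self_attention" once the wrapper prefixes are stripped.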
def _clean_metric_name(name: str) -> str:
    """Strip model wrapper prefixes from metric names for cleaner logging."""
    prefixes = ["model.module.module.", "model.module.", "model."]
    for prefix in prefixes:
        if name.startswith(prefix):
            return name[len(prefix) :]
    return name


def _maybe_attach_metric_loggers(tensorboard_logger: Any, wandb_logger: Any) -> None:
    """Attach TensorBoard and W&B loggers to nvdlfw_inspect."""
    if not HAVE_NVINSPECT:
        return

    try:
        if tensorboard_logger is not None:
            tb_logger = wrap_tensorboard_writer(tensorboard_logger)
            MetricLogger.add_logger(tb_logger)

        if wandb_logger is not None and hasattr(wandb_logger, "log"):
            if BaseLogger is None:
                return

            class _WandbModuleLogger(BaseLogger):
                def __init__(self, wandb_module):
                    super().__init__()
                    self._wandb = wandb_module

                def log_scalar(self, name: str, value: float, iteration: int, **kwargs):
                    clean_name = _clean_metric_name(name)
                    self._wandb.log({clean_name: value}, step=iteration)

            MetricLogger.add_logger(_WandbModuleLogger(wandb_logger))

    except Exception as e:
        print_rank_0(f"Warning: Failed to attach metric loggers to tensor inspection: {e}")


def initialize_tensor_inspect_pre_model(
    enabled: bool,
    config_file: Optional[str] = None,
    feature_dirs: Optional[List[str]] = None,
    log_dir: Optional[str] = None,
    init_training_step: int = 0,
) -> None:
    """Initialize NVIDIA-DL-Framework-Inspect before model construction."""
    if not enabled:
        return

    if not HAVE_NVINSPECT:
        raise ImportError(MISSING_NVINSPECT_MSG)

    if feature_dirs is None:
        feature_dirs = _get_default_feature_dirs()

    nvinspect_api.initialize(
        config_file=config_file or "",
        feature_dirs=feature_dirs,
        log_dir=log_dir or ".",
        statistics_logger=None,
        init_training_step=init_training_step,
        default_logging_enabled=True,
    )
    print_rank_0("Initialized NVIDIA DLFw Inspect.")


def finalize_tensor_inspect_post_model(
    enabled: bool,
    model: List[Any],
    tensorboard_logger: Any = None,
    wandb_logger: Any = None,
    current_training_step: Optional[int] = None,
    include_context_parallel: bool = True,
) -> None:
    """Finalize tensor inspection setup after model creation."""
    if not enabled:
        return

    if not HAVE_NVINSPECT:
        raise ImportError(MISSING_NVINSPECT_MSG)

    from megatron.core.parallel_state import get_tensor_and_data_parallel_group

    _maybe_attach_metric_loggers(tensorboard_logger, wandb_logger)

    if current_training_step is not None:
        nvinspect_api.initialize_training_step(int(current_training_step))

    nvinspect_api.infer_and_assign_layer_names(model)
    nvinspect_api.set_tensor_reduction_group(
        get_tensor_and_data_parallel_group(with_context_parallel=include_context_parallel)
    )
    print_rank_0("Finalized NVIDIA DLFw Inspect.")


def tensor_inspect_step(enabled: bool) -> None:
    """Advance the tensor inspection step counter."""
    if not enabled:
        return

    if not HAVE_NVINSPECT:
        raise ImportError(MISSING_NVINSPECT_MSG)

    nvinspect_api.step()


def tensor_inspect_end(enabled: bool) -> None:
    """Shut down tensor inspection."""
    if not enabled or not HAVE_NVINSPECT:
        return

    nvinspect_api.end_debug()
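Taken together, these functions define a four-phase lifecycle: initialize before the model is built, finalize after it exists, step once per iteration, and end at shutdown. A minimal sketch of the intended call order, mirroring how training.py below wires it in. It assumes nvdlfw-inspect is installed; `build_model_chunks` and `run_train_step` are hypothetical stand-ins for Megatron's real setup and training-step code, and the paths are illustrative:

```python
from megatron.training.tensor_inspect import (
    initialize_tensor_inspect_pre_model,
    finalize_tensor_inspect_post_model,
    tensor_inspect_step,
    tensor_inspect_end,
)

def build_model_chunks():
    """Placeholder for setup_model_and_optimizer(...); returns the list of model chunks."""
    raise NotImplementedError

def run_train_step():
    """Placeholder for the real Megatron training step."""
    raise NotImplementedError

# 1. Before model construction: start nvdlfw_inspect with config and log paths.
initialize_tensor_inspect_pre_model(
    enabled=True,
    config_file="configs/inspect.yaml",  # illustrative path
    log_dir="./inspect_logs",            # illustrative path
)

model = build_model_chunks()  # Megatron passes a list of model chunks

# 2. After model construction: attach loggers, assign layer names,
#    and set the tensor reduction group.
finalize_tensor_inspect_post_model(enabled=True, model=model)

# 3. Once per training iteration: advance the inspection step counter.
for _ in range(10):  # illustrative iteration count
    run_train_step()
    tensor_inspect_step(enabled=True)

# 4. At shutdown: tear down the inspection machinery.
tensor_inspect_end(enabled=True)
```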
32 changes: 32 additions & 0 deletions megatron/training/training.py
@@ -171,6 +171,12 @@ def set_startup_timestamps(program_start=None, main_entry=None):
)

from .async_utils import maybe_finalize_async_save
from .tensor_inspect import (
    initialize_tensor_inspect_pre_model,
    finalize_tensor_inspect_post_model,
    tensor_inspect_step,
    tensor_inspect_end,
)
from .utils import (
    append_to_progress_log,
    calc_params_l2_norm,
@@ -920,6 +926,15 @@ def pretrain(
    else:
        checkpointing_context = {}

    if getattr(args, 'tensor_inspect', False):
        initialize_tensor_inspect_pre_model(
            enabled=True,
            config_file=args.tensor_inspect_config,
            feature_dirs=args.tensor_inspect_feature_dirs,
            log_dir=args.tensor_inspect_log_dir or args.save,
            init_training_step=getattr(args, 'iteration', 0),
        )

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
@@ -1014,6 +1029,16 @@ def pretrain(
"This flag is only useful when doing refit since the weights are shared with the training model."
)

    if getattr(args, 'tensor_inspect', False):
        finalize_tensor_inspect_post_model(
            enabled=True,
            model=model,
            tensorboard_logger=get_tensorboard_writer(),
            wandb_logger=get_wandb_writer(),
            current_training_step=args.iteration,
            include_context_parallel=True,
        )

    # Data stuff.
    app_metrics['app_build_dataiters_start_time'] = one_logger_utils.get_timestamp_in_ms()
    timers('train/valid/test-data-iterators-setup', log_level=0).start(barrier=True)
@@ -2894,6 +2919,10 @@ def trace_handler(p):
            forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func, iteration=iteration
        )
        ft_integration.on_training_step_end()

        if getattr(args, 'tensor_inspect', False):
            tensor_inspect_step(enabled=True)

        if should_checkpoint:
            save_checkpoint_and_time(
                iteration,
@@ -3119,6 +3148,9 @@ def trace_handler(p):
print_rank_0(f"Total training energy (GPU): {total_energy / 1e6:.3f} MJ")
energy_monitor.shutdown()

    if getattr(args, 'tensor_inspect', False):
        tensor_inspect_end(enabled=True)

    # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
    if should_exit:
        wandb_writer = get_wandb_writer()