Skip to content

Commit 64ae361

Browse files
committed
remove unrelated changes
Signed-off-by: Hongbin Liu <[email protected]>
1 parent cb612c7 commit 64ae361

File tree

9 files changed

+28
-253
lines changed

9 files changed

+28
-253
lines changed

megatron/core/tensor_parallel/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@
2929
from .random import (
3030
CheckpointWithoutOutput,
3131
checkpoint,
32-
convert_cuda_rng_state,
3332
get_cuda_rng_tracker,
3433
get_data_parallel_rng_tracker_name,
3534
get_expert_parallel_rng_tracker_name,
36-
is_graph_safe_cuda_rng_tracker,
3735
model_parallel_cuda_manual_seed,
3836
)
3937
from .utils import (
@@ -66,11 +64,9 @@
6664
"scatter_to_sequence_parallel_region",
6765
# random.py
6866
"checkpoint",
69-
"convert_cuda_rng_state",
7067
"get_cuda_rng_tracker",
7168
"model_parallel_cuda_manual_seed",
7269
"get_expert_parallel_rng_tracker_name",
73-
"is_graph_safe_cuda_rng_tracker",
7470
"CheckpointWithoutOutput",
7571
# utils.py
7672
"split_tensor_along_last_dim",

megatron/core/tensor_parallel/random.py

Lines changed: 6 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
22

33
# Parts of the code here are adapted from PyTorch
44
# repo: https://github.com/pytorch/pytorch
@@ -111,41 +111,6 @@ def cb():
111111
_lazy_call(cb)
112112

113113

114-
def convert_cuda_rng_state(
115-
state: Union[torch.Tensor, torch.Generator], to_graphable: bool = False
116-
) -> Union[torch.Tensor, torch.Generator]:
117-
"""
118-
Convert the cuda rng state tensor to the graphable version,
119-
or from the graphable version to the non-graphable tensor version.
120-
"""
121-
if to_graphable:
122-
if isinstance(state, torch.Tensor):
123-
# Convert to the graphable version.
124-
# Store current rng state.
125-
orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
126-
# Set rng state to the desired one
127-
_set_cuda_rng_state(state, graph_safe=False)
128-
# Get the graphable state
129-
graphable_state = _get_cuda_rng_state(clone=True, graph_safe=True)
130-
# And set the state to the original state we started with.
131-
_set_cuda_rng_state(orig_cuda_rng_state, graph_safe=False)
132-
return graphable_state
133-
elif isinstance(state, torch.Generator):
134-
# already graphable, just return it.
135-
return state
136-
else:
137-
raise ValueError(f"Invalid state type: {type(state)}")
138-
else:
139-
if isinstance(state, torch.Tensor):
140-
# already non-graphable, just return it.
141-
return state
142-
elif isinstance(state, torch.Generator):
143-
# Convert to the non-graphable tensor version.
144-
return state.get_state()
145-
else:
146-
raise ValueError(f"Invalid state type: {type(state)}")
147-
148-
149114
def get_expert_parallel_rng_tracker_name():
150115
"""Get the expert parallel rng tracker name"""
151116
global _EXPERT_PARALLEL_RNG_TRACKER_NAME
@@ -196,10 +161,6 @@ def reset(self):
196161
# Seeds are just for bookkeeping and ensure no seed is set twice.
197162
self.seeds_ = set()
198163

199-
# Name of the rng state currently being used in the generator.
200-
# The default one is "default-rng" and won't be pushed to the self.states_ dictionary.
201-
self._current_state_name = "default-rng"
202-
203164
def get_states(self):
204165
"""Get rng states. Copy the dictionary so we have direct
205166
pointers to the states, not just a pointer to the dictionary."""
@@ -246,14 +207,10 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
246207
# Check if we have added the state
247208
if name not in self.states_:
248209
raise Exception('cuda rng state {} is not added'.format(name))
249-
# Store current rng state and name. Store in self.states_ if it's not the default state.
210+
# Store current rng state.
250211
orig_cuda_rng_state = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
251-
orig_state_name = self._current_state_name
252-
if orig_state_name != "default-rng":
253-
self.states_[orig_state_name] = orig_cuda_rng_state
254-
# Set rng state and name to the desired one.
212+
# Set rng state to the desired one
255213
_set_cuda_rng_state(self.states_[name], graph_safe=self.use_cudagraphable_rng)
256-
self._current_state_name = name
257214
# Record cpu RNG state
258215
cpu_rng_state = torch.get_rng_state()
259216
# Do the stuff we wanted to do.
@@ -263,19 +220,10 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
263220
# Throw a warning if cpu RNG state changed
264221
if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
265222
logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
266-
# Check if the current state name is the same as the desired state name.
267-
if self._current_state_name != name:
268-
raise Exception(
269-
f'current state name {self._current_state_name} is not the same as the desired '
270-
f'state name {name}.'
271-
)
272223
# Update the current rng state for later use.
273224
self.states_[name] = _get_cuda_rng_state(graph_safe=self.use_cudagraphable_rng)
274-
# And set the state and name to the original state we started with.
275-
if orig_state_name != "default-rng":
276-
orig_cuda_rng_state = self.states_[orig_state_name]
225+
# And set the state to the original state we started with.
277226
_set_cuda_rng_state(orig_cuda_rng_state, graph_safe=self.use_cudagraphable_rng)
278-
self._current_state_name = orig_state_name
279227

280228

281229
# RNG tracker object.
@@ -429,34 +377,18 @@ def model_parallel_cuda_manual_seed(
429377
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
430378

431379

432-
def is_graph_safe_cuda_rng_tracker(cuda_rng_tracker):
433-
"""Check if the cuda rng tracker is graph safe version."""
434-
if HAVE_TE and is_te_min_version("1.5.0"):
435-
from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker
436-
437-
if isinstance(cuda_rng_tracker, TECudaRNGStatesTracker):
438-
return True
439-
if getattr(cuda_rng_tracker, "use_cudagraphable_rng", False):
440-
return True
441-
return False
442-
443-
444380
def _get_all_rng_states():
445381
"""Get all the rng states."""
446382
cpu_rng_state = torch.get_rng_state()
447-
cuda_rng_state = _get_cuda_rng_state(
448-
graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker())
449-
)
383+
cuda_rng_state = _get_cuda_rng_state()
450384
cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
451385
return cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker
452386

453387

454388
def _set_all_rng_states(cpu_rng_state, cuda_rng_state, cuda_rng_state_tracker):
455389
"""Set all the rng states."""
456390
torch.set_rng_state(cpu_rng_state)
457-
_set_cuda_rng_state(
458-
cuda_rng_state, graph_safe=is_graph_safe_cuda_rng_tracker(get_cuda_rng_tracker())
459-
)
391+
_set_cuda_rng_state(cuda_rng_state)
460392
get_cuda_rng_tracker().set_states(cuda_rng_state_tracker)
461393

462394

megatron/core/transformer/cuda_graphs.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,12 +1911,7 @@ def create_cudagraphs(self):
19111911

19121912
# Prepare CUDA Graph capturing input data and call `make_graphed_callables`.
19131913
sample_args, kwargs = self._get_cuda_graph_input_data()
1914-
if self.config.sequence_parallel:
1915-
rng_context = get_cuda_rng_tracker().fork()
1916-
else:
1917-
rng_context = nullcontext()
1918-
with rng_context:
1919-
graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs)
1914+
graphs = make_graphed_callables(tuple(self.flattened_callables), sample_args, **kwargs)
19201915

19211916
# Push the captured graphs to the corresponding TransformerBlock.
19221917
num_layers_accumulated = 0

megatron/core/transformer/moe/moe_utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from megatron.core.fp4_utils import get_fp4_align_size
1212
from megatron.core.fp8_utils import get_fp8_align_size
1313
from megatron.core.process_groups_config import ProcessGroupCollection
14-
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
1514
from megatron.core.transformer.cuda_graphs import is_graph_capturing
1615
from megatron.core.transformer.enums import CudaGraphScope
1716
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -919,7 +918,6 @@ def get_moe_layer_wise_logging_tracker():
919918
return _MOE_LAYER_WISE_LOGGING_TRACKER
920919

921920

922-
@internal_api
923921
class RandomSTE(torch.autograd.Function):
924922
"""
925923
Straight-Through Estimator(STE) function that returns random values
@@ -928,14 +926,26 @@ class RandomSTE(torch.autograd.Function):
928926
This is used to generate random logits of router for load-balanced benchmark.
929927
"""
930928

929+
generator = None
930+
random_logits = None
931+
931932
@staticmethod
932933
def forward(ctx, logits):
933934
"""
934935
Forward pass returns random logits with rank-specific seed.
935936
"""
936-
with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()):
937-
random_logits = logits.clone().normal_()
938-
return random_logits
937+
if is_graph_capturing() and RandomSTE.random_logits is not None:
938+
return RandomSTE.random_logits
939+
940+
if RandomSTE.generator is None:
941+
global_rank = torch.distributed.get_rank()
942+
base_seed = 42
943+
seed = base_seed + global_rank
944+
RandomSTE.generator = torch.Generator(device=logits.device)
945+
RandomSTE.generator.manual_seed(seed)
946+
947+
RandomSTE.random_logits = logits.clone().normal_(generator=RandomSTE.generator)
948+
return RandomSTE.random_logits
939949

940950
@staticmethod
941951
def backward(ctx, grad_output):

megatron/core/transformer/multi_token_prediction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
22

33
import warnings
44
from contextlib import nullcontext

megatron/training/arguments.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,10 +1255,7 @@ def validate_args(args, defaults={}):
12551255

12561256
# CUDA Graphs
12571257
if args.cuda_graph_impl != "none":
1258-
if (
1259-
"transformer_engine" in (args.transformer_impl, args.cuda_graph_impl)
1260-
and not args.te_rng_tracker
1261-
):
1258+
if args.transformer_impl == 'transformer_engine' and not args.te_rng_tracker:
12621259
args.te_rng_tracker = True
12631260
warn_rank_0("te_rng_tracker is not enabled, enabling it for CUDA graphs.", args.rank)
12641261
assert (

megatron/training/checkpointing.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,8 +1766,6 @@ def load_model_state_dict(module, state_dict, strict: bool):
17661766
# rng states.
17671767
if not release and not args.finetune and not args.no_load_rng and not ignore_rng_state:
17681768
try:
1769-
cuda_rng_tracker = tensor_parallel.get_cuda_rng_tracker()
1770-
graph_safe_rng = tensor_parallel.is_graph_safe_cuda_rng_tracker(cuda_rng_tracker)
17711769
if 'rng_state' in state_dict:
17721770
if args.ckpt_format == "fsdp_dtensor":
17731771
# FSDP DTensor checkpoints store rng_state in a different format.
@@ -1793,10 +1791,8 @@ def load_model_state_dict(module, state_dict, strict: bool):
17931791
# Check for empty states array
17941792
if not rng_state['rng_tracker_states']:
17951793
raise KeyError
1796-
rng_tracker_states = {
1797-
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
1798-
for k, v in rng_state['rng_tracker_states'].items()
1799-
}
1794+
tensor_parallel.get_cuda_rng_tracker().set_states(
1795+
rng_state['rng_tracker_states'])
18001796
else: # backward compatibility
18011797
random.setstate(state_dict['random_rng_state'])
18021798
np.random.set_state(state_dict['np_rng_state'])
@@ -1805,11 +1801,8 @@ def load_model_state_dict(module, state_dict, strict: bool):
18051801
# Check for empty states array
18061802
if not state_dict['rng_tracker_states']:
18071803
raise KeyError
1808-
rng_tracker_states = {
1809-
k: tensor_parallel.convert_cuda_rng_state(v, to_graphable=graph_safe_rng)
1810-
for k, v in state_dict['rng_tracker_states'].items()
1811-
}
1812-
cuda_rng_tracker.set_states(rng_tracker_states)
1804+
tensor_parallel.get_cuda_rng_tracker().set_states(
1805+
state_dict['rng_tracker_states'])
18131806
except KeyError:
18141807
print_rank_0('Unable to load rng state from checkpoint {}. '
18151808
'Specify --no-load-rng or --finetune to prevent '

megatron/training/training.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,9 +646,6 @@ def pretrain(
646646
)
647647
set_ideal_affinity_for_current_gpu()
648648

649-
if args.batch_invariant_mode:
650-
print_rank_0("Enabling batch invariant mode globally",flush=True)
651-
enable_batch_invariant_mode()
652649

653650
if args.log_progress:
654651
append_to_progress_log("Starting job")

0 commit comments

Comments
 (0)