1- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES . All rights reserved.
1+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
22
33# Parts of the code here are adapted from PyTorch
44# repo: https://github.com/pytorch/pytorch
@@ -111,41 +111,6 @@ def cb():
111111 _lazy_call (cb )
112112
113113
114- def convert_cuda_rng_state (
115- state : Union [torch .Tensor , torch .Generator ], to_graphable : bool = False
116- ) -> Union [torch .Tensor , torch .Generator ]:
117- """
118- Convert the cuda rng state tensor to the graphable version,
119- or from the graphable version to the non-graphable tensor version.
120- """
121- if to_graphable :
122- if isinstance (state , torch .Tensor ):
123- # Convert to the graphable version.
124- # Store current rng state.
125- orig_cuda_rng_state = _get_cuda_rng_state (graph_safe = False )
126- # Set rng state to the desired one
127- _set_cuda_rng_state (state , graph_safe = False )
128- # Get the graphable state
129- graphable_state = _get_cuda_rng_state (clone = True , graph_safe = True )
130- # And set the state to the original state we started with.
131- _set_cuda_rng_state (orig_cuda_rng_state , graph_safe = False )
132- return graphable_state
133- elif isinstance (state , torch .Generator ):
134- # already graphable, just return it.
135- return state
136- else :
137- raise ValueError (f"Invalid state type: { type (state )} " )
138- else :
139- if isinstance (state , torch .Tensor ):
140- # already non-graphable, just return it.
141- return state
142- elif isinstance (state , torch .Generator ):
143- # Convert to the non-graphable tensor version.
144- return state .get_state ()
145- else :
146- raise ValueError (f"Invalid state type: { type (state )} " )
147-
148-
149114def get_expert_parallel_rng_tracker_name ():
150115 """Get the expert parallel rng tracker name"""
151116 global _EXPERT_PARALLEL_RNG_TRACKER_NAME
@@ -196,10 +161,6 @@ def reset(self):
196161 # Seeds are just for book keeping and ensure no seed is set twice.
197162 self .seeds_ = set ()
198163
199- # Name of the rng state currently being used in the generator.
200- # The default one is "default-rng" and won't be pushed to the self.states_ dictionary.
201- self ._current_state_name = "default-rng"
202-
203164 def get_states (self ):
204165 """Get rng states. Copy the dictionary so we have direct
205166 pointers to the states, not just a pointer to the dictionary."""
@@ -246,14 +207,10 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
246207 # Check if we have added the state
247208 if name not in self .states_ :
248209 raise Exception ('cuda rng state {} is not added' .format (name ))
249- # Store current rng state and name. Store in self.states_ if it's not the default state .
210+ # Store current rng state.
250211 orig_cuda_rng_state = _get_cuda_rng_state (graph_safe = self .use_cudagraphable_rng )
251- orig_state_name = self ._current_state_name
252- if orig_state_name != "default-rng" :
253- self .states_ [orig_state_name ] = orig_cuda_rng_state
254- # Set rng state and name to the desired one.
212+ # Set rng state to the desired one
255213 _set_cuda_rng_state (self .states_ [name ], graph_safe = self .use_cudagraphable_rng )
256- self ._current_state_name = name
257214 # Record cpu RNG state
258215 cpu_rng_state = torch .get_rng_state ()
259216 # Do the stuff we wanted to do.
@@ -263,19 +220,10 @@ def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
263220 # Throw a warning if cpu RNG state changed
264221 if not torch .all (cpu_rng_state == torch .get_rng_state ()).item ():
265222 logging .getLogger (__name__ ).warning ('CPU RNG state changed within GPU RNG context' )
266- # Check if the current state name is the same as the desired state name.
267- if self ._current_state_name != name :
268- raise Exception (
269- f'current state name { self ._current_state_name } is not the same as the desired '
270- f'state name { name } .'
271- )
272223 # Update the current rng state for later use.
273224 self .states_ [name ] = _get_cuda_rng_state (graph_safe = self .use_cudagraphable_rng )
274- # And set the state and name to the original state we started with.
275- if orig_state_name != "default-rng" :
276- orig_cuda_rng_state = self .states_ [orig_state_name ]
225+ # And set the state to the original state we started with.
277226 _set_cuda_rng_state (orig_cuda_rng_state , graph_safe = self .use_cudagraphable_rng )
278- self ._current_state_name = orig_state_name
279227
280228
281229# RNG tracker object.
@@ -429,34 +377,18 @@ def model_parallel_cuda_manual_seed(
429377 _CUDA_RNG_STATE_TRACKER .add (_EXPERT_PARALLEL_RNG_TRACKER_NAME , expert_parallel_seed )
430378
431379
432- def is_graph_safe_cuda_rng_tracker (cuda_rng_tracker ):
433- """Check if the cuda rng tracker is graph safe version."""
434- if HAVE_TE and is_te_min_version ("1.5.0" ):
435- from megatron .core .extensions .transformer_engine import TECudaRNGStatesTracker
436-
437- if isinstance (cuda_rng_tracker , TECudaRNGStatesTracker ):
438- return True
439- if getattr (cuda_rng_tracker , "use_cudagraphable_rng" , False ):
440- return True
441- return False
442-
443-
444380def _get_all_rng_states ():
445381 """Get all the rng states."""
446382 cpu_rng_state = torch .get_rng_state ()
447- cuda_rng_state = _get_cuda_rng_state (
448- graph_safe = is_graph_safe_cuda_rng_tracker (get_cuda_rng_tracker ())
449- )
383+ cuda_rng_state = _get_cuda_rng_state ()
450384 cuda_rng_state_tracker = get_cuda_rng_tracker ().get_states ()
451385 return cpu_rng_state , cuda_rng_state , cuda_rng_state_tracker
452386
453387
454388def _set_all_rng_states (cpu_rng_state , cuda_rng_state , cuda_rng_state_tracker ):
455389 """Set all the rng states."""
456390 torch .set_rng_state (cpu_rng_state )
457- _set_cuda_rng_state (
458- cuda_rng_state , graph_safe = is_graph_safe_cuda_rng_tracker (get_cuda_rng_tracker ())
459- )
391+ _set_cuda_rng_state (cuda_rng_state )
460392 get_cuda_rng_tracker ().set_states (cuda_rng_state_tracker )
461393
462394
0 commit comments