@@ -256,7 +256,7 @@ def step(self, closure=None):
256256 self .to_gpu () # needed for fairseq pure fp16 training
257257 self .initialized = True
258258
259- if self .is_paged : self .page_mng .prefetch_all ()
259+ # if self.is_paged: self.page_mng.prefetch_all()
260260 for gindex , group in enumerate (self .param_groups ):
261261 for pindex , p in enumerate (group ["params" ]):
262262 if p .grad is None :
@@ -265,7 +265,9 @@ def step(self, closure=None):
265265 if len (state ) == 0 :
266266 self .init_state (group , p , gindex , pindex )
267267
268+ self .prefetch_state (p )
268269 self .update_step (group , p , gindex , pindex )
270+ torch .cuda .synchronize ()
269271 if self .is_paged :
270272 # all paged operation are asynchronous, we need
271273 # to sync to make sure all tensors are in the right state
@@ -309,6 +311,13 @@ def get_state_buffer(self, p, dtype=torch.float32):
309311 self .page_mng .paged_tensors .append (buff )
310312 return buff
311313
314+ def prefetch_state (self , p ):
315+ if self .is_paged :
316+ state = self .state [p ]
317+ F .prefetch_tensor (state ['state1' ])
318+ if 'state2' in state :
319+ F .prefetch_tensor (state ['state2' ])
320+
312321
313322class Optimizer2State (Optimizer8bit ):
314323 def __init__ (
0 commit comments