@@ -83,7 +83,10 @@ def onload_(self):
8383
8484 with context :
8585 for group_module in self .modules :
86- group_module .to (self .onload_device , non_blocking = self .non_blocking )
86+ for param in group_module .parameters ():
87+ param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
88+ for buffer in group_module .buffers ():
89+ buffer .data = buffer .data .to (self .onload_device , non_blocking = self .non_blocking )
8790 if self .parameters is not None :
8891 for param in self .parameters :
8992 param .data = param .data .to (self .onload_device , non_blocking = self .non_blocking )
@@ -98,6 +101,12 @@ def offload_(self):
98101 for group_module in self .modules :
99102 for param in group_module .parameters ():
100103 param .data = self .cpu_param_dict [param ]
104+ if self .parameters is not None :
105+ for param in self .parameters :
106+ param .data = self .cpu_param_dict [param ]
107+ if self .buffers is not None :
108+ for buffer in self .buffers :
109+ buffer .data = self .cpu_param_dict [buffer ]
101110 else :
102111 for group_module in self .modules :
103112 group_module .to (self .offload_device , non_blocking = self .non_blocking )
@@ -387,9 +396,7 @@ def _apply_group_offloading_block_level(
387396 # Create a pinned CPU parameter dict for async data transfer if streams are to be used
388397 cpu_param_dict = None
389398 if stream is not None :
390- for param in module .parameters ():
391- param .data = param .data .cpu ().pin_memory ()
392- cpu_param_dict = {param : param .data for param in module .parameters ()}
399+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
393400
394401 # Create module groups for ModuleList and Sequential blocks
395402 modules_with_group_offloading = set ()
@@ -486,9 +493,7 @@ def _apply_group_offloading_leaf_level(
486493 # Create a pinned CPU parameter dict for async data transfer if streams are to be used
487494 cpu_param_dict = None
488495 if stream is not None :
489- for param in module .parameters ():
490- param .data = param .data .cpu ().pin_memory ()
491- cpu_param_dict = {param : param .data for param in module .parameters ()}
496+ cpu_param_dict = _get_pinned_cpu_param_dict (module )
492497
493498 # Create module groups for leaf modules and apply group offloading hooks
494499 modules_with_group_offloading = set ()
@@ -604,6 +609,17 @@ def _apply_lazy_group_offloading_hook(
604609 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
605610
606611
def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.Tensor, torch.Tensor]:
    """Pin all parameters and buffers of ``module`` in CPU memory and index them.

    Every tensor yielded by ``module.parameters()`` and ``module.buffers()`` has
    its ``.data`` replaced, in place, by a copy that lives in pinned CPU memory.
    The returned dict maps each parameter/buffer object to that pinned tensor.

    Args:
        module: The module whose parameters and buffers are moved. It is
            mutated: after the call its tensors' ``.data`` live on the CPU.

    Returns:
        A dict keyed by the parameter or buffer object itself, whose value is
        the pinned CPU tensor assigned to its ``.data``. Keys include buffers,
        which are plain tensors, hence the ``torch.Tensor`` key type. The dict
        is empty when the module has no parameters and no buffers.
    """
    cpu_param_dict = {}
    # Parameters are visited first, then buffers, so the insertion order of the
    # resulting dict is parameters followed by buffers.
    for tensor in [*module.parameters(), *module.buffers()]:
        tensor.data = tensor.data.cpu().pin_memory()
        cpu_param_dict[tensor] = tensor.data
    return cpu_param_dict
621+
622+
607623def _gather_parameters_with_no_group_offloading_parent (
608624 module : torch .nn .Module , modules_with_group_offloading : Set [str ]
609625) -> List [torch .nn .Parameter ]:
0 commit comments