11import torch
22import torch .nn as nn
33from typing import Dict
4-
4+ import platform
55
66def enable_sequential_cpu_offload (module : nn .Module , device : str = "cuda" ):
77 module = module .to ("cpu" )
@@ -26,13 +26,14 @@ def _forward_pre_hook(module: nn.Module, input_):
2626 for name , buffer in module .named_buffers (recurse = recurse ):
2727 buffer .data = buffer .data .to (device = device )
2828 return tuple (x .to (device = device ) if isinstance (x , torch .Tensor ) else x for x in input_ )
29-
30- for name , param in module . named_parameters ( recurse = recurse ) :
31- param .data = param .data .pin_memory ()
29+ for name , param in module . named_parameters ( recurse = recurse ):
30+ if platform . system () == 'Linux' :
31+ param .data = param .data .pin_memory ()
3232 offload_param_dict [name ] = param .data
3333 param .data = param .data .to (device = device )
3434 for name , buffer in module .named_buffers (recurse = recurse ):
35- buffer .data = buffer .data .pin_memory ()
35+ if platform .system () == 'Linux' :
36+ buffer .data = buffer .data .pin_memory ()
3637 offload_param_dict [name ] = buffer .data
3738 buffer .data = buffer .data .to (device = device )
3839 setattr (module , "_offload_param_dict" , offload_param_dict )
def offload_model_to_dict(module: nn.Module) -> Dict[str, torch.Tensor]:
    """Move *module* to CPU and collect its parameters and buffers into a dict.

    On Linux with a usable CUDA runtime the CPU tensors are additionally
    moved into pinned (page-locked) memory so that a later host-to-device
    copy can run asynchronously.

    Args:
        module: Model whose weights should be offloaded to CPU. The module
            is modified in place (moved to CPU, tensors possibly re-pinned).

    Returns:
        Mapping from ``named_parameters()`` / ``named_buffers()`` name to the
        corresponding CPU tensor.

    Note:
        If a buffer shares a name with a parameter, the buffer's entry
        overwrites the parameter's — names are assumed unique across both
        sets (true for standard torch modules).
    """
    module = module.to("cpu")
    offload_param_dict: Dict[str, torch.Tensor] = {}
    # Hoisted loop invariant: the platform probe cannot change per tensor.
    # Pinning is restricted to Linux (matching the rest of this file) and
    # additionally requires CUDA — pin_memory() raises RuntimeError on
    # CUDA-less builds, so we fall back to plain CPU tensors there.
    use_pinned = platform.system() == "Linux" and torch.cuda.is_available()
    for name, param in module.named_parameters(recurse=True):
        if use_pinned:
            param.data = param.data.pin_memory()
        offload_param_dict[name] = param.data
    for name, buffer in module.named_buffers(recurse=True):
        if use_pinned:
            buffer.data = buffer.data.pin_memory()
        offload_param_dict[name] = buffer.data
    return offload_param_dict
0 commit comments