-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcheckpoint_handler.py
More file actions
144 lines (124 loc) · 6 KB
/
checkpoint_handler.py
File metadata and controls
144 lines (124 loc) · 6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import torch
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
try:
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend as xb
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
import torch_xla.utils.serialization as xser
except ImportError:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
StateDictType,
ShardedStateDictConfig,
ShardedOptimStateDictConfig
)
xm = None
xr = None
from device_utils import get_current_device_type
from train_config import TrainConfig
from logging_handler import get_logger
logger = get_logger()
class CheckpointHandler:
    """Save and restore model/optimizer/scheduler state for three runtimes.

    Dispatch (in ``save``/``load``) is:
      * FSDP + XLA available  -> ``torch_xla`` sharded serialization (``xser``)
      * FSDP on CUDA          -> ``torch.distributed.checkpoint`` state-dict APIs
      * plain (non-FSDP)      -> vanilla ``torch.save``/``torch.load``

    Each rank writes its own file under
    ``<checkpoint_dir>/<device_type>/rank_<r>-world_<w>.pt``, so checkpoints
    are only resumable with the same world size.  Epoch numbering is 1-based:
    loaders return ``saved_epoch + 1`` (the next epoch to run) or 1 when no
    usable checkpoint exists.
    """

    def __init__(self, cfg: TrainConfig):
        self.cfg = cfg
        self.__init_chkpt_path()
        if get_current_device_type() == "cuda":
            # NOTE(review): these configs are constructed but never passed to
            # get_state_dict/set_state_dict below, which use their own defaults.
            # Kept for backward compatibility in case external code reads them;
            # confirm whether they can be dropped.
            self.sharded_state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
            self.sharded_optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)

    def __init_chkpt_path(self):
        """Compute the per-rank checkpoint path and create its directory."""
        self.__chkpt_path = os.path.join(
            self.cfg.checkpoint_dir,
            get_current_device_type(),
            f"rank_{self.cfg.rank}-world_{self.cfg.world_size}.pt")
        logger.info(f"Checkpoint path: {self.__chkpt_path}")
        os.makedirs(os.path.dirname(self.__chkpt_path), exist_ok=True)

    def __load_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                          scheduler: LRScheduler) -> int:
        """Restore a non-FSDP checkpoint; return next epoch (1 on failure)."""
        epoch = 1
        try:
            # weights_only=False: the file contains optimizer/scheduler state,
            # not just tensors. Only load checkpoints you produced yourself.
            state_dict = torch.load(self.__chkpt_path, weights_only=False)
            model.load_state_dict(state_dict['model'])
            optimizer.load_state_dict(state_dict['optimizer'])
            scheduler.load_state_dict(state_dict['scheduler'])
            epoch = state_dict['epoch'] + 1
        except Exception as e:
            # Best-effort resume: a missing/corrupt checkpoint means "start fresh".
            logger.warning(f"load chkpt error: {e}")
        return epoch

    def __load_fsdp_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                               scheduler: LRScheduler) -> int:
        """Restore an FSDP (CUDA) checkpoint; return next epoch (1 on failure)."""
        epoch = 1
        try:
            state_dict = torch.load(self.__chkpt_path, weights_only=False)
            # set_state_dict handles re-sharding model/optimizer state onto
            # the wrapped FSDP modules.
            set_state_dict(
                model,
                optimizer,
                model_state_dict=state_dict["model"],
                optim_state_dict=state_dict["optimizer"]
            )
            scheduler.load_state_dict(state_dict['scheduler'])
            epoch = state_dict['epoch'] + 1
        except Exception as e:
            logger.warning(f"load chkpt error: {e}")
        return epoch

    def __save_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                          scheduler: LRScheduler, epoch: int) -> None:
        """Write a non-FSDP checkpoint. Failures are logged, not raised."""
        try:
            msd = model.state_dict()
            osd = optimizer.state_dict()
            ssd = scheduler.state_dict()
            torch.save({"model": msd, "optimizer": osd, "scheduler": ssd, "epoch": epoch}, self.__chkpt_path)
        except Exception as e:
            logger.warning(f"save chkpt error: {e}")

    def __save_fsdp_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                               scheduler: LRScheduler, epoch: int) -> None:
        """Write an FSDP (CUDA) checkpoint via the distributed state-dict API."""
        try:
            # get_state_dict gathers FSDP-sharded model/optimizer state into
            # loadable dicts.
            model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer)
            scheduler_state_dict = scheduler.state_dict()
            torch.save({"model": model_state_dict,
                        "optimizer": optimizer_state_dict,
                        "scheduler": scheduler_state_dict,
                        "epoch": epoch},
                       self.__chkpt_path)
        except Exception as e:
            logger.warning(f"save chkpt error: {e}")

    def __save_fsdp_xla_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                                   scheduler: LRScheduler, epoch: int) -> None:
        """Write an FSDP-XLA checkpoint; every rank saves its own shard."""
        try:
            state_dict = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                # shard_metadata is required by XLA FSDP to reconstruct a full
                # (consolidated) state dict later.
                'shard_metadata': model.get_shard_metadata(),
                "epoch": epoch
            }
            # master_only=False: each rank persists its shard to its own path.
            xser.save(state_dict, self.__chkpt_path, master_only=False)
        except Exception as e:
            logger.warning(f"save chkpt error: {e}")

    def __load_fsdp_xla_checkpoint(self, model: torch.nn.Module, optimizer: Optimizer,
                                   scheduler: LRScheduler) -> int:
        """Restore an FSDP-XLA checkpoint; return next epoch (1 on failure)."""
        # BUG FIX: epoch was previously assigned only inside the try block, so
        # any load failure raised UnboundLocalError at the return statement
        # instead of falling back to epoch 1 like the other loaders.
        epoch = 1
        try:
            state_dict = xser.load(self.__chkpt_path)
            model.load_state_dict(state_dict['model'])
            optimizer.load_state_dict(state_dict['optimizer'])
            scheduler.load_state_dict(state_dict['scheduler'])
            epoch = state_dict['epoch'] + 1
        except Exception as e:
            logger.warning(f"load chkpt error: {e}")
        return epoch

    def save(self, model: torch.nn.Module, optimizer: Optimizer,
             scheduler: LRScheduler, epoch: int) -> None:
        """Persist training state for *epoch*, dispatching on runtime.

        ``xm`` is non-None only when the torch_xla import at module top
        succeeded, so it doubles as the XLA-availability flag.
        """
        if self.cfg.fsdp:
            if xm:
                self.__save_fsdp_xla_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch)
            else:
                self.__save_fsdp_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch)
        else:
            self.__save_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch)

    def load(self, model: torch.nn.Module, optimizer: Optimizer,
             scheduler: LRScheduler) -> int:
        """Restore training state in place; return the epoch to resume from.

        Returns 1 (train from scratch) when no checkpoint file exists or
        loading fails.
        """
        if not os.path.isfile(self.__chkpt_path):
            logger.info("No checkpoint available")
            return 1
        if self.cfg.fsdp:
            if xm:
                return self.__load_fsdp_xla_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler)
            else:
                return self.__load_fsdp_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler)
        else:
            return self.__load_checkpoint(model=model, optimizer=optimizer, scheduler=scheduler)