 1 |    | -from typing import Dict, Iterable, Iterator, Tuple, Union
   |  1 | +from typing import Dict, Iterable, Iterator, Union
 2 |  2 |
 3 |  3 |
 4 |  4 |  from .global_var import config
 8 |  8 |  from .parameter import DistributedParameter, OpAllGather
 9 |  9 |  from .checkpointing import ScopedTensorInspectorContext
10 | 10 |  from . import debug
11 |    | -from torch.nn.modules.module import _addindent
12 | 11 |  import copy
13 | 12 |
14 | 13 |  def round_up(x, d):
@@ -331,7 +330,8 @@ def __init__(self, inner_module : torch.nn.Module):
331 | 330 |
332 | 331 |          # calc total number of parameters
333 | 332 |          for name, param in ordered_parameters:
334 |     | -            assert isinstance(param, DistributedParameter), "All parameters in checkpoint block must be DistributedParameter."
    | 333 | +            if not isinstance(param, DistributedParameter):
    | 334 | +                raise ValueError("All parameters in checkpoint block must be DistributedParameter.")
335 | 335 |
336 | 336 |              storage_type = storage_type_cuda(param.storage_type())
337 | 337 |              kw_name = _get_param_kw(param)
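
For context: `assert` statements are compiled away when Python runs with `-O`, so the explicit `raise ValueError` above keeps the type check active in optimized mode. A minimal standalone sketch of the same validation pattern, using a hypothetical `Param` class in place of the real `DistributedParameter`:

    class Param:  # hypothetical stand-in for DistributedParameter
        pass

    def validate(ordered_parameters):
        # Reject any entry that is not a Param; unlike assert, this
        # check still fires under `python -O`.
        for name, param in ordered_parameters:
            if not isinstance(param, Param):
                raise ValueError("All parameters in checkpoint block must be DistributedParameter.")

    validate([("weight", Param()), ("bias", 3)])  # raises ValueError on the second entry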
@@ -464,7 +464,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
464 | 464 |          # gather here
465 | 465 |          with torch.no_grad():
466 | 466 |              with CheckpointBlockContext(self):
467 |     | -                return self._module.state_dict(destination, prefix, keep_vars)
    | 467 | +                return self._module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
468 | 468 |
469 | 469 |      def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
470 | 470 |                                missing_keys, unexpected_keys, error_msgs):
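
For context: `torch.nn.Module.state_dict` takes `destination`, `prefix`, and `keep_vars` as keyword arguments, and recent PyTorch releases warn when they are passed positionally, so the keyword form above is the forward-compatible spelling. A small usage sketch against a plain `torch.nn.Linear` rather than the `CheckpointBlock` defined in this file:

    import torch

    m = torch.nn.Linear(4, 4)

    # Keyword arguments mirror the call in the hunk above; passing
    # destination/prefix/keep_vars positionally triggers a deprecation
    # warning on newer PyTorch versions.
    sd = m.state_dict(prefix="block.", keep_vars=False)
    print(sorted(sd.keys()))  # ['block.bias', 'block.weight']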