@@ -411,6 +411,7 @@ def _load_checkpoint(
411
411
state : dict [str , Union [Module , Optimizer , Any ]],
412
412
strict : bool = True ,
413
413
optimizer_states_from_list : bool = False ,
414
+ weights_only : bool = False ,
414
415
) -> dict [str , Any ]:
415
416
from torch .distributed .checkpoint .state_dict import (
416
417
StateDictOptions ,
@@ -449,7 +450,7 @@ def _load_checkpoint(
449
450
set_optimizer_state_dict (module , optim , optim_state_dict = optim_state [optim_key ], options = state_dict_options )
450
451
451
452
# Load metadata (anything not a module or optimizer)
452
- metadata = torch .load (path / _METADATA_FILENAME )
453
+ metadata = torch .load (path / _METADATA_FILENAME , weights_only = weights_only )
453
454
requested_metadata_keys = state .keys () - modules .keys () - optimizers .keys ()
454
455
_validate_keys_for_strict_loading (requested_metadata_keys , metadata .keys (), strict = strict )
455
456
for key in requested_metadata_keys :
@@ -461,7 +462,7 @@ def _load_checkpoint(
461
462
return metadata
462
463
463
464
if _is_full_checkpoint (path ):
464
- checkpoint = torch .load (path , mmap = True , map_location = "cpu" , weights_only = False )
465
+ checkpoint = torch .load (path , mmap = True , map_location = "cpu" , weights_only = weights_only )
465
466
_load_raw_module_state (checkpoint .pop (module_key ), module , strict = strict )
466
467
467
468
state_dict_options = StateDictOptions (
0 commit comments