Open
Labels: feature (Is an improvement or enhancement)
Description & Motivation
Hi,

It would be nice to have a `dtype` argument for `load_from_checkpoint`, alongside the very handy `map_location`. It would let the user set the model's dtype directly through `load_from_checkpoint`, without having to call `.to(dtype)` manually afterwards.
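For reference, the status quo is a two-step "load, then cast". The sketch below reproduces it in plain PyTorch as a stand-in for Lightning's loader. The `nn.Linear` model and the in-memory checkpoint with a `"state_dict"` key are illustrative assumptions in place of a real `LightningModule` and checkpoint file.

```python
import io

import torch
from torch import nn

# Stand-in model; any nn.Module behaves the same way here.
model = nn.Linear(4, 2)

# Save a checkpoint to an in-memory buffer (a file path works the same way).
buffer = io.BytesIO()
torch.save({"state_dict": model.state_dict()}, buffer)
buffer.seek(0)

# Step 1: load, with map_location choosing where the tensors end up.
checkpoint = torch.load(buffer, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(checkpoint["state_dict"])

# Step 2: the manual cast that the proposal would fold into the loader.
restored = restored.to(torch.float16)
```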
Pitch
The fix should be pretty straightforward. Currently we have:
```python
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Self:
    ...
```
It would just get one more argument:
```python
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    dtype: Optional[torch.dtype] = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Self:
    ...
```
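To make the proposal concrete, here is a minimal standalone sketch of a loader with the extra `dtype` parameter. It is not Lightning's implementation: `load_with_dtype` and its `build_model` callback are hypothetical names, and device placement is left out for brevity. The point it illustrates is that a `None` default keeps today's behaviour for existing callers.

```python
import io
from typing import IO, Any, Callable, Optional, Union

import torch
from torch import nn


def load_with_dtype(
    build_model: Callable[..., nn.Module],
    checkpoint: Union[str, IO[bytes]],
    map_location: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
    **kwargs: Any,
) -> nn.Module:
    """Hypothetical loader mirroring the proposed signature (not Lightning's API)."""
    state = torch.load(checkpoint, map_location=map_location)
    model = build_model(**kwargs)  # kwargs go to the hypothetical model factory
    model.load_state_dict(state["state_dict"])
    # dtype=None keeps the checkpoint's dtype, so existing callers are unaffected.
    if dtype is not None:
        model = model.to(dtype)
    return model


# Usage: load an in-memory checkpoint straight into bfloat16.
buffer = io.BytesIO()
torch.save({"state_dict": nn.Linear(4, 2).state_dict()}, buffer)
buffer.seek(0)
half = load_with_dtype(lambda: nn.Linear(4, 2), buffer, map_location="cpu", dtype=torch.bfloat16)
```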
Since this delegates to `_load_from_checkpoint`, we can add our extra `dtype` argument there too and easily change this line:
```python
return model.to(device)
```
to
```python
return model.to(dtype).to(device)
```
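One detail worth handling is the default `dtype=None`. Rather than relying on how `.to()` treats a `None` argument, the sketch below forwards only the values that were actually given, using the `device=` and `dtype=` keywords of `nn.Module.to`. `move_model` is a hypothetical helper name, not an existing Lightning function.

```python
from typing import Any, Dict, Optional, Union

import torch
from torch import nn


def move_model(
    model: nn.Module,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Hypothetical helper: cast and move in one call, skipping None arguments."""
    kwargs: Dict[str, Any] = {}
    if device is not None:
        kwargs["device"] = device
    if dtype is not None:
        # nn.Module.to only accepts floating-point or complex dtypes here.
        kwargs["dtype"] = dtype
    return model.to(**kwargs) if kwargs else model
```

A useful side effect of going through `nn.Module.to` is that only floating-point (and complex) parameters and buffers are cast, so integer buffers such as `BatchNorm`'s `num_batches_tracked` keep their dtype.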
Hope this is taken into consideration :)
Alternatives
No response
Additional context
No response
intexcor