
Add a dtype option for load_from_checkpoint #20833

@arijit-hub


Description & Motivation

Hi,
It would be nice to have a dtype argument for load_from_checkpoint alongside the very handy map_location. It would let the user set the model's dtype directly in load_from_checkpoint, without having to call .to(dtype) manually afterwards.

Pitch

The fix should be pretty straightforward. Currently we have:

def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
...

It would just get one more argument:

def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        dtype: Optional[torch.dtype] = None,
        hparams_file: Optional[_PATH] = None,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
...
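To make the intended behaviour of the new argument concrete without depending on Lightning internals, here is a minimal sketch in plain PyTorch. `load_with_dtype` and `TinyNet` are hypothetical names used only for illustration (they are not part of Lightning's API), and the sketch assumes the checkpoint is a bare `state_dict` rather than a full Lightning `.ckpt` dictionary.

```python
import io
from typing import Any, Optional, Type

import torch
from torch import nn


def load_with_dtype(
    cls: Type[nn.Module],
    checkpoint: Any,
    map_location: Any = None,
    dtype: Optional[torch.dtype] = None,
    **kwargs: Any,
) -> nn.Module:
    """Hypothetical helper mirroring the proposed `dtype` argument.

    Assumes `checkpoint` holds a plain state_dict, not a Lightning checkpoint.
    """
    state_dict = torch.load(checkpoint, map_location=map_location)
    model = cls(**kwargs)
    model.load_state_dict(state_dict)
    # Only cast when a dtype was requested; nn.Module.to(dtype=...) converts
    # floating-point (and complex) parameters and buffers, leaving integer ones alone.
    if dtype is not None:
        model = model.to(dtype=dtype)
    return model


class TinyNet(nn.Module):
    # Hypothetical toy module used only to exercise the helper.
    def __init__(self, hidden: int = 4) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden, 2)


# Save a float32 state_dict into an in-memory buffer, then reload it as bfloat16.
buffer = io.BytesIO()
torch.save(TinyNet().state_dict(), buffer)
buffer.seek(0)

model = load_with_dtype(TinyNet, buffer, map_location="cpu", dtype=torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16
```

With `dtype` left at `None`, the helper behaves exactly like a plain load, which is the backward-compatible default the proposal relies on.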

Since this method actually delegates to _load_from_checkpoint:

def _load_from_checkpoint(

We can add the extra dtype argument there as well, and then simply change this line

to

return model.to(dtype).to(device)

Hope this is taken into consideration :)

@Borda @awaelchli @lantiga

Alternatives

No response

Additional context

No response

Metadata

Assignees: No one assigned
Labels: feature (Is an improvement or enhancement)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
