You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This commit makes two changes during model creation:
1. Decouples promote_trainable_params_to_fp32 from model __init__. This
is to avoid casting to fp32 to save memory in inference-only mode
(#4).
2. Use a context manager to manage default tensor type change. In the
previous version, the default tensor type is reset to
torch.FloatTensor after creating the vision model, which is
technically incorrect and should be the previous default tensor type
instead. We implement our own context manager because the official
context managers seem to be incomplete at this time (PyTorch 2.0.1):
No dtype manager is provided and set_default_device is ineffective to
the torch.Tensor calls which are used in fairscale.
0 commit comments