
Move the module to the precision dtype #12591

@yuvalkirstain

Description

🐛 Bug

@SeanNaren
When I use bf16 and check the dtype of the model, it looks like the model's parameters are still fp32 (and I do not see the memory gains I expect). In other frameworks that support bf16 (like fairseq), the model's dtype is torch.bfloat16. Is there a simple example that demonstrates this feature reduces memory consumption as it should? I suspect something might be wrong here (but of course, I might be mistaken).
Thank you!

To Reproduce

Launch any job with precision="bf16" and compare it with precision=32.
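
A minimal reproduction sketch, assuming the PyTorch Lightning 1.6-style Trainer API (accelerator/devices/precision="bf16") and a CUDA GPU; the tiny module and random dataset below are hypothetical stand-ins, not the original job:

```python
# Run one epoch with precision="bf16" and inspect the parameter dtype afterwards.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    model = TinyModule()
    trainer = pl.Trainer(
        accelerator="gpu", devices=1, precision="bf16",
        max_epochs=1, enable_checkpointing=False, logger=False,
    )
    trainer.fit(model, DataLoader(ds, batch_size=16))
    # Reported behavior: this prints {torch.float32} even though
    # precision="bf16" was requested.
    print({p.dtype for p in model.parameters()})
```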

Expected behavior

This feature should save roughly 30-50% of memory, but I do not see such gains in Lightning.
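
As a back-of-the-envelope check of where that saving comes from (independent of Lightning): bfloat16 parameters take 2 bytes per element versus 4 bytes for float32, so casting the module itself, as fairseq does, halves parameter memory. A small sketch in plain PyTorch:

```python
import torch

module = torch.nn.Linear(4096, 4096)

def param_bytes(m: torch.nn.Module) -> int:
    # Total bytes held by the module's parameters.
    return sum(p.numel() * p.element_size() for p in m.parameters())

fp32_bytes = param_bytes(module)                       # 4 bytes per element
bf16_bytes = param_bytes(module.to(torch.bfloat16))    # 2 bytes per element

print(f"fp32: {fp32_bytes / 2**20:.1f} MiB, bf16: {bf16_bytes / 2**20:.1f} MiB")
# bf16 parameter memory is half of fp32; optimizer state and activations
# behave differently, which is why end-to-end savings land around 30-50%.
```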

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 3090
    • available: True
    • version: 11.3
  • Packages:
    • numpy: 1.21.2
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0
    • pytorch-lightning: 1.6.0dev
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.12
• version: #74-Ubuntu SMP Tue Sep 17 17:06:04 UTC 2019

Additional context

bf16 is a very important feature. It is usually more stable than fp16, and Lightning should support it properly (models pretrained with bf16 should not be run in fp16) :)

cc @Borda @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta
