diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 7a2bbd6f9a..bc529b58e4 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -86,9 +86,12 @@ def resolve_dtype(cfg): if cfg.bf16: cfg.fp16 = False - if cfg.bf16 or cfg.bfloat16: + # For mixed precision, we want the base model loaded in fp32 + if cfg.fp16 or cfg.bf16: + cfg.torch_dtype = torch.float32 + elif cfg.bfloat16: cfg.torch_dtype = torch.bfloat16 - elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: + elif cfg.load_in_8bit or cfg.float16: cfg.torch_dtype = torch.float16 else: cfg.torch_dtype = torch.float32