Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ def resolve_dtype(cfg):
if cfg.bf16:
cfg.fp16 = False

if cfg.bf16 or cfg.bfloat16:
# For mixed precision, we want the base model loaded in fp32
if cfg.fp16 or cfg.bf16:
cfg.torch_dtype = torch.float32
elif cfg.bfloat16:
cfg.torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
elif cfg.load_in_8bit or cfg.float16:
cfg.torch_dtype = torch.float16
else:
cfg.torch_dtype = torch.float32
Expand Down
Loading