Skip to content

Commit 3dd9c3b

Browse files
authored
setup hf transfer too and fix auto bf16 when fp16 enabled (axolotl-ai-cloud#2620) [skip ci]
1 parent 0ba7d36 commit 3dd9c3b

File tree

5 files changed

+18
-7
lines changed

5 files changed

+18
-7
lines changed

src/axolotl/cli/evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from axolotl.cli.config import load_cfg
1616
from axolotl.common.datasets import load_datasets, load_preference_datasets
1717
from axolotl.evaluate import evaluate
18-
from axolotl.utils import set_pytorch_cuda_alloc_conf
18+
from axolotl.utils import patch_optimized_env
1919
from axolotl.utils.dict import DictDefault
2020

2121
LOG = logging.getLogger(__name__)
@@ -32,7 +32,7 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
3232
cli_args: CLI arguments.
3333
"""
3434
# Enable expandable segments for cuda allocation to improve VRAM usage
35-
set_pytorch_cuda_alloc_conf()
35+
patch_optimized_env()
3636

3737
# pylint: disable=duplicate-code
3838
print_axolotl_text_art()

src/axolotl/cli/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
filter_none_kwargs,
3030
)
3131
from axolotl.integrations.lm_eval.cli import lm_eval
32-
from axolotl.utils import set_pytorch_cuda_alloc_conf
32+
from axolotl.utils import patch_optimized_env
3333
from axolotl.utils.schemas.config import AxolotlInputConfig
3434

3535

@@ -55,6 +55,8 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
5555
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
5656
config options.
5757
"""
58+
patch_optimized_env()
59+
5860
if cloud:
5961
from axolotl.cli.cloud import do_cli_preprocess
6062

@@ -100,7 +102,7 @@ def train(
100102
config options.
101103
"""
102104
# Enable expandable segments for cuda allocation to improve VRAM usage
103-
set_pytorch_cuda_alloc_conf()
105+
patch_optimized_env()
104106

105107
if "use_ray" in kwargs and kwargs["use_ray"]:
106108
accelerate = False

src/axolotl/cli/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from axolotl.common.datasets import load_datasets, load_preference_datasets
1919
from axolotl.integrations.base import PluginManager
2020
from axolotl.train import train
21-
from axolotl.utils import set_pytorch_cuda_alloc_conf
21+
from axolotl.utils import patch_optimized_env
2222
from axolotl.utils.config import normalize_config, resolve_dtype
2323
from axolotl.utils.dict import DictDefault
2424

@@ -36,7 +36,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
3636
cli_args: Training-specific CLI arguments.
3737
"""
3838
# Enable expandable segments for cuda allocation to improve VRAM usage
39-
set_pytorch_cuda_alloc_conf()
39+
patch_optimized_env()
4040

4141
print_axolotl_text_art()
4242
check_accelerate_default_config()

src/axolotl/utils/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,12 @@ def set_pytorch_cuda_alloc_conf():
4343
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
4444
"expandable_segments:True,roundup_power2_divisions:16"
4545
)
46+
47+
48+
def patch_optimized_env():
49+
"""
50+
Patch environment variables to improve VRAM usage and increase download speed
51+
"""
52+
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
53+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
54+
set_pytorch_cuda_alloc_conf()

src/axolotl/utils/config/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_device():
5959

6060
def resolve_dtype(cfg):
6161
if (
62-
cfg.bf16 == "auto" and not cfg.use_ray
62+
not cfg.fp16 and cfg.bf16 == "auto" and not cfg.use_ray
6363
): # if we use ray we want to defer this check to the worker node
6464
if is_torch_bf16_gpu_available():
6565
LOG.debug("bf16 support detected, enabling for this configuration.")

0 commit comments

Comments (0)