Commit fee3c13

Logging config for colab (axolotl-ai-cloud#2611)
* only configure logging on cli to play nicely with colab
* allow reloading the config on the fly from a dict
* make sure to use dict for yaml
* reuse existing function for load
* make cli args optional
* mps fix and respect max_steps
1 parent 996fc12 commit fee3c13

File tree

11 files changed (+33, -26 lines)


src/axolotl/cli/__init__.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -2,4 +2,7 @@
 
 import os
 
+from axolotl.logging_config import configure_logging
+
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+configure_logging()
```
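This import-time call becomes the single place logging gets configured; the matching module-level `configure_logging()` calls are removed from the library modules below. A minimal sketch of the resulting pattern, assuming a Colab/notebook session with axolotl installed:

```python
import logging

# Library modules now do only this at import time: no handlers are
# installed, so a notebook's own logging setup is left alone.
LOG = logging.getLogger(__name__)

# Opting in to axolotl's log formatting is an explicit import of the CLI
# package, which runs configure_logging() as a side effect.
import axolotl.cli  # noqa: F401
```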

src/axolotl/cli/checks.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -8,9 +8,6 @@
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 
-from axolotl.logging_config import configure_logging
-
-configure_logging()
 LOG = logging.getLogger(__name__)
 
 
```

src/axolotl/cli/config.py

Lines changed: 22 additions & 10 deletions

```diff
@@ -5,6 +5,7 @@
 import os
 import tempfile
 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from typing import Union
 from urllib.parse import urlparse
 
@@ -158,7 +159,9 @@ def plugin_set_cfg(cfg: DictDefault):
     plugin_manager.cfg = cfg
 
 
-def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
+def load_cfg(
+    config: str | Path | DictDefault = Path("examples/"), **kwargs
+) -> DictDefault:
     """
     Loads the `axolotl` configuration stored at `config`, validates it, and performs
     various setup.
@@ -170,13 +173,24 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
     Returns:
         `DictDefault` mapping configuration keys to values.
     """
-    config = check_remote_config(config)
-    if Path(config).is_dir():
-        config = choose_config(Path(config))
-
-    # Load the config from the yaml file
-    with open(config, encoding="utf-8") as file:
-        cfg: DictDefault = DictDefault(yaml.safe_load(file))
+    if isinstance(config, (str, Path)):
+        config = check_remote_config(config)
+        if Path(config).is_dir():
+            config = choose_config(Path(config))
+
+        # Load the config from the yaml file
+        with open(config, encoding="utf-8") as file:
+            cfg: DictDefault = DictDefault(yaml.safe_load(file))
+
+        cfg.axolotl_config_path = config
+    else:
+        cfg = config
+        with NamedTemporaryFile(
+            mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+        ) as temp_file:
+            temp_file.write(yaml.dump(config.to_dict()))
+            temp_file.close()
+            cfg.axolotl_config_path = temp_file.name
 
     # If there are any options passed in the cli, if it is something that seems valid
     # from the yaml, then overwrite the value
@@ -190,8 +204,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefa
             else:
                 cfg[k] = kwargs[k]
 
-    cfg.axolotl_config_path = config
-
     try:
         device_props = torch.cuda.get_device_properties("cuda")
         gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
```
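The widened signature is what enables the "reloading the config on the fly from a dict" item in the commit message: a `DictDefault` built in a notebook cell can be passed straight to `load_cfg`, which dumps it to a temporary YAML file so downstream code that expects `axolotl_config_path` to be a real path keeps working. A usage sketch, assuming axolotl is installed; the config keys are illustrative placeholders, not a complete validated config:

```python
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

# Build the config in-notebook instead of editing a YAML file on disk.
raw_cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # illustrative
        "sequence_len": 2048,
        "micro_batch_size": 1,
    }
)

cfg = load_cfg(raw_cfg)

# The dict was serialized to a NamedTemporaryFile behind the scenes.
print(cfg.axolotl_config_path)  # e.g. /tmp/axolotl_config_....yml
```

Note that `delete=False` on the temporary file matters here: the path has to outlive `load_cfg` so later stages can re-read the YAML.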

src/axolotl/cli/utils.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -20,11 +20,9 @@
     ProcessorMixin,
 )
 
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_processor, load_tokenizer
 
-configure_logging()
 LOG = logging.getLogger(__name__)
 
 
```

src/axolotl/common/datasets.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -47,7 +47,7 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
 def load_datasets(
     *,
     cfg: DictDefault,
-    cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
+    cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
 ) -> TrainDatasetMeta:
     """
     Loads one or more training or evaluation datasets, calling
@@ -64,7 +64,8 @@
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
     preprocess_iterable = (
-        hasattr(cli_args, "iterable")
+        cli_args
+        and hasattr(cli_args, "iterable")
         and cli_args.iterable is not None
         and cli_args.iterable
     )
@@ -76,7 +77,7 @@
         preprocess_iterable=preprocess_iterable,
     )
 
-    if (
+    if cli_args and (
         cli_args.debug
         or cfg.debug
         or cli_args.debug_text_only
```
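With `cli_args` now defaulting to `None`, notebook callers can skip constructing `PreprocessCliArgs`/`TrainerCliArgs` entirely; the short-circuiting `cli_args and ...` guards above turn the iterable and debug branches into no-ops in that case. A sketch, reusing the `cfg` from the `load_cfg` example; the `train_dataset` field name is an assumption about `TrainDatasetMeta`, not shown in this diff:

```python
from axolotl.common.datasets import load_datasets

# No CLI argument object needed outside the CLI.
dataset_meta = load_datasets(cfg=cfg)

# Assumed field name on TrainDatasetMeta, for illustration only.
train_dataset = dataset_meta.train_dataset
```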

src/axolotl/core/trainer_builder.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -488,7 +488,7 @@ def build(self, total_num_steps):
 
         # these are all the "standard" kwargs that are def used
         training_arguments_kwargs["max_steps"] = (
-            total_num_steps if self.cfg.max_steps else -1
+            self.cfg.max_steps if self.cfg.max_steps else -1
        )
         training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
         training_arguments_kwargs["per_device_train_batch_size"] = (
```
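This one-line change is the "respect max_steps" item from the commit message: previously, setting `max_steps` in the config handed the computed `total_num_steps` to the trainer, silently ignoring the user's cap. A standalone paraphrase of the old and new logic (helper names are mine, not from the codebase):

```python
def old_max_steps(cfg_max_steps, total_num_steps):
    # buggy: any truthy cfg.max_steps passed total_num_steps through
    return total_num_steps if cfg_max_steps else -1

def new_max_steps(cfg_max_steps, total_num_steps):
    # fixed: the user's cap is what actually reaches TrainingArguments
    return cfg_max_steps if cfg_max_steps else -1

assert old_max_steps(100, 5000) == 5000  # cap silently ignored
assert new_max_steps(100, 5000) == 100   # cap respected
assert new_max_steps(None, 5000) == -1   # epoch-based fallback unchanged
```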

src/axolotl/evaluate.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -11,7 +11,6 @@
 from datasets import Dataset
 from transformers.trainer import Trainer
 
-from axolotl.logging_config import configure_logging
 from axolotl.train import (
     TrainDatasetMeta,
     setup_model_and_tokenizer,
@@ -24,7 +23,6 @@
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
```

src/axolotl/monkeypatch/attention/ring_attn/patch.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -12,10 +12,8 @@
 import torch.distributed as dist
 from accelerate.logging import get_logger
 
-from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
```

src/axolotl/train.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -30,7 +30,6 @@
     SequenceParallelContextManager,
 )
 from axolotl.integrations.base import PluginManager
-from axolotl.logging_config import configure_logging
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import cleanup_distributed
 from axolotl.utils.freeze import freeze_layers_except
@@ -42,7 +41,6 @@
 except ImportError:
     BetterTransformer = None
 
-configure_logging()
 LOG = get_logger(__name__)
 
 
```

src/axolotl/utils/config/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -67,7 +67,7 @@ def resolve_dtype(cfg):
     else:
         LOG.debug("bf16 support not detected, disabling for this configuration.")
         cfg.bf16 = False
-    if cfg.fp16 is None:
+    if cfg.fp16 is None and not cfg.float16:
         cfg.fp16 = True
 
     if cfg.device == "mps":
```
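This guard is part of the "mps fix" bullet: without it, a config that sets `float16` (half-precision weights) but leaves `fp16` (the mixed-precision trainer flag) unset would get `fp16` force-enabled, which is unwanted for the Apple MPS handling just below. A standalone paraphrase, assuming `DictDefault`'s behavior of returning `None` for missing keys:

```python
from axolotl.utils.dict import DictDefault

def default_fp16(cfg):
    # Patched check: only default fp16 on when the user hasn't already
    # opted into plain float16 weights instead.
    if cfg.fp16 is None and not cfg.float16:
        cfg.fp16 = True
    return cfg

print(default_fp16(DictDefault({"float16": True})).fp16)  # None: not forced
print(default_fp16(DictDefault({})).fp16)                 # True: old default
```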
