Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/llama_cookbook/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import inspect
from dataclasses import asdict
from dataclasses import asdict, fields
from enum import Enum
from typing import get_type_hints

import torch.distributed as dist
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
Expand All @@ -19,19 +21,43 @@
from llama_cookbook.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_cookbook.datasets import DATASET_PREPROC


def _convert_to_field_type(config, param_name, value):
"""
Convert a value to the expected field type for a dataclass field.
Handles enum conversion from string values (e.g., "FULL_SHARD" -> ShardingStrategy.FULL_SHARD).
"""
try:
type_hints = get_type_hints(type(config))
expected_type = type_hints.get(param_name)

if expected_type is not None and isinstance(value, str):
# Check if expected type is an Enum
if isinstance(expected_type, type) and issubclass(expected_type, Enum):
return expected_type[value]
except (KeyError, TypeError, ValueError):
# If conversion fails, return original value and let it fail later with a clear error
pass
return value


def update_config(config, **kwargs):
if isinstance(config, (tuple, list)):
for c in config:
update_config(c, **kwargs)
else:
for k, v in kwargs.items():
if hasattr(config, k):
# Convert string values to enum types if needed
v = _convert_to_field_type(config, k, v)
setattr(config, k, v)
elif "." in k:
# allow --some_config.some_param=True
config_name, param_name = k.split(".")
if type(config).__name__ == config_name:
if hasattr(config, param_name):
# Convert string values to enum types if needed
v = _convert_to_field_type(config, param_name, v)
setattr(config, param_name, v)
else:
# In case of specialized config we can warn user
Expand Down