Skip to content

Commit ea17a2b

Browse files
Fix the data issue and add logging setup for a better experience.
1 parent ec67a21 commit ea17a2b

File tree

10 files changed

+394
-149
lines changed

10 files changed

+394
-149
lines changed

quantllm/__init__.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,25 @@
1111
TrainingLogger
1212
)
1313
from .hub import HubManager, CheckpointManager
14+
from .utils.optimizations import get_optimal_training_settings
15+
from .utils.log_config import configure_logging, enable_logging
1416

1517
from .config import (
1618
ModelConfig,
1719
DatasetConfig,
1820
TrainingConfig
1921
)
2022

23+
# Configure package-wide logging
24+
configure_logging()
25+
2126
__version__ = "0.1.0"
2227

28+
# Package metadata
29+
__title__ = "QuantLLM"
30+
__description__ = "Efficient Quantized LLM Fine-Tuning Library"
31+
__author__ = "QuantLLM Team"
32+
2333
__all__ = [
2434
# Model
2535
"Model",
@@ -42,5 +52,18 @@
4252
# Configuration
4353
"ModelConfig",
4454
"DatasetConfig",
45-
"TrainingConfig"
46-
]
55+
"TrainingConfig",
56+
57+
# Utilities
58+
"get_optimal_training_settings",
59+
"configure_logging",
60+
"enable_logging",
61+
]
62+
63+
# Initialize package-level logger with fancy welcome message
64+
logger = TrainingLogger()
65+
logger.log_success(f"""
66+
✨ QuantLLM v{__version__} initialized successfully ✨
67+
🚀 Efficient Quantized Language Model Fine-Tuning
68+
📚 Documentation: https://github.com/yourusername/QuantLLM
69+
""")

quantllm/data/dataset_preprocessor.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
11
from datasets import Dataset
2-
from typing import Optional, Dict, Any, Callable
2+
from typing import Optional, Dict, Any, Callable, Tuple
33
from transformers import PreTrainedTokenizer
44
from ..trainer.logger import TrainingLogger
5+
from tqdm.auto import tqdm
6+
import logging
7+
import warnings
8+
9+
# Disable unnecessary logging
10+
logging.getLogger("tokenizers").setLevel(logging.ERROR)
11+
warnings.filterwarnings("ignore")
512

613
class DatasetPreprocessor:
7-
def __init__(self, tokenizer, logger=None):
14+
def __init__(self, tokenizer: PreTrainedTokenizer, logger: Optional[TrainingLogger] = None):
815
self.tokenizer = tokenizer
916
self.logger = logger or TrainingLogger()
1017

1118
# Set pad token if not set
1219
if self.tokenizer.pad_token is None:
1320
self.tokenizer.pad_token = self.tokenizer.eos_token
1421
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
15-
print("Set pad token to eos token")
22+
self.logger.log_info("Set pad token to eos token")
1623

1724
def validate_datasets(self, datasets):
1825
"""Validate input datasets."""
1926
for dataset in datasets:
2027
if dataset is not None and not isinstance(dataset, Dataset):
2128
raise ValueError(f"Expected Dataset object, got {type(dataset)}")
2229

23-
def preprocess_text(self, text):
30+
def preprocess_text(self, text: str) -> str:
2431
"""Basic text preprocessing"""
2532
if not text:
2633
return ""
@@ -30,98 +37,98 @@ def preprocess_text(self, text):
3037

3138
def tokenize_dataset(
3239
self,
33-
train_dataset,
34-
val_dataset=None,
35-
test_dataset=None,
40+
train_dataset: Dataset,
41+
val_dataset: Optional[Dataset] = None,
42+
test_dataset: Optional[Dataset] = None,
3643
max_length: int = 512,
3744
text_column: str = "text",
38-
label_column: str = None,
45+
label_column: Optional[str] = None,
3946
batch_size: int = 1000
40-
):
41-
"""Tokenize datasets with preprocessing."""
47+
) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
48+
"""Tokenize datasets with preprocessing and progress bars."""
4249
try:
4350
self.validate_datasets([train_dataset, val_dataset, test_dataset])
4451

4552
def process_and_tokenize_batch(examples):
46-
# Get texts and preprocess
53+
# Get texts and preprocess with progress indication
4754
texts = examples[text_column]
4855
if not isinstance(texts, list):
4956
texts = [texts]
57+
58+
# Preprocess texts
5059
texts = [self.preprocess_text(text) for text in texts]
5160

5261
try:
5362
# Tokenize with padding and truncation
54-
# Use max_length + 1 to account for the shift we'll do later
5563
tokenized = self.tokenizer(
5664
texts,
5765
padding="max_length",
5866
truncation=True,
59-
max_length=max_length + 1, # Add 1 to account for shift
67+
max_length=max_length + 1, # Add 1 for shift
6068
return_tensors=None
6169
)
6270

71+
# For causal language modeling, prepare shifted sequences
6372
input_ids = tokenized["input_ids"]
6473
attention_mask = tokenized["attention_mask"]
6574

66-
# Now shift to create inputs and labels
67-
# inputs will be [:-1] and labels will be [1:]
75+
# Prepare shifted sequences for input and labels
6876
labels = [ids[1:] for ids in input_ids]
6977
input_ids = [ids[:-1] for ids in input_ids]
7078
attention_mask = [mask[:-1] for mask in attention_mask]
7179

72-
# Verify all sequences have the expected length
80+
# Verify sequence lengths
7381
expected_length = max_length
74-
if not all(len(seq) == expected_length for seq in input_ids):
75-
raise ValueError(f"Input sequence lengths don't match. Expected {expected_length}")
76-
if not all(len(seq) == expected_length for seq in attention_mask):
77-
raise ValueError(f"Attention mask lengths don't match. Expected {expected_length}")
78-
if not all(len(seq) == expected_length for seq in labels):
79-
raise ValueError(f"Label sequence lengths don't match. Expected {expected_length}")
82+
assert all(len(seq) == expected_length for seq in input_ids), "Input sequence lengths don't match"
83+
assert all(len(seq) == expected_length for seq in attention_mask), "Attention mask lengths don't match"
84+
assert all(len(seq) == expected_length for seq in labels), "Label sequence lengths don't match"
8085

8186
result = {
8287
"input_ids": input_ids,
8388
"attention_mask": attention_mask,
8489
"labels": labels
8590
}
8691

87-
self.logger.log_info(f"Tokenized batch of {len(texts)} texts")
8892
return result
8993

9094
except Exception as e:
9195
self.logger.log_error(f"Error tokenizing batch: {str(e)}")
9296
raise
9397

94-
# Process datasets
98+
# Process datasets with overall progress bars
99+
self.logger.log_info("Processing training dataset")
95100
train_tokenized = train_dataset.map(
96101
process_and_tokenize_batch,
97102
batched=True,
98103
batch_size=batch_size,
99104
remove_columns=train_dataset.column_names,
100105
desc="Tokenizing training set"
101106
)
102-
self.logger.log_info(f"Tokenized training dataset: {len(train_tokenized)} examples")
107+
self.logger.log_success(f"Tokenized training dataset: {len(train_tokenized)} examples")
103108

104109
val_tokenized = None
105110
if val_dataset is not None:
111+
self.logger.log_info("Processing validation dataset")
106112
val_tokenized = val_dataset.map(
107113
process_and_tokenize_batch,
108114
batched=True,
109115
batch_size=batch_size,
110116
remove_columns=val_dataset.column_names,
111117
desc="Tokenizing validation set"
112118
)
113-
self.logger.log_info(f"Tokenized validation dataset: {len(val_tokenized)} examples")
119+
self.logger.log_success(f"Tokenized validation dataset: {len(val_tokenized)} examples")
114120

115121
test_tokenized = None
116122
if test_dataset is not None:
123+
self.logger.log_info("Processing test dataset")
117124
test_tokenized = test_dataset.map(
118125
process_and_tokenize_batch,
119126
batched=True,
120127
batch_size=batch_size,
121128
remove_columns=test_dataset.column_names,
122129
desc="Tokenizing test set"
123130
)
124-
self.logger.log_info(f"Tokenized test dataset: {len(test_tokenized)} examples")
131+
self.logger.log_success(f"Tokenized test dataset: {len(test_tokenized)} examples")
125132

126133
# Set format to PyTorch tensors
127134
train_tokenized.set_format("torch")

quantllm/data/dataset_splitter.py

Lines changed: 103 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
1-
from datasets import Dataset
2-
from typing import Optional, Dict, Any, Tuple
1+
from datasets import Dataset, DatasetDict
2+
from typing import Optional, Tuple, Union
3+
import numpy as np
34
from ..trainer.logger import TrainingLogger
5+
from tqdm.auto import tqdm
6+
import logging
7+
8+
# Configure logging
9+
logging.getLogger("datasets").setLevel(logging.WARNING)
410

511
class DatasetSplitter:
6-
def __init__(self, logger=None):
12+
def __init__(self, logger: Optional[TrainingLogger] = None):
13+
"""Initialize dataset splitter."""
714
self.logger = logger or TrainingLogger()
8-
15+
16+
def _get_dataset_from_dict(self, dataset: Union[Dataset, DatasetDict], split: str = "train") -> Dataset:
17+
"""Extract dataset from DatasetDict if needed."""
18+
if isinstance(dataset, DatasetDict):
19+
if split in dataset:
20+
return dataset[split]
21+
raise ValueError(f"DatasetDict does not contain split '{split}'")
22+
return dataset
23+
924
def validate_split_params(self, train_size: float, val_size: float, test_size: float = None):
1025
"""Validate split parameters."""
1126
if train_size <= 0 or train_size >= 1:
@@ -55,52 +70,104 @@ def train_test_split(
5570
self.logger.log_error(f"Error splitting dataset: {str(e)}")
5671
raise
5772

58-
def train_val_test_split(self, dataset, train_size: float, val_size: float, test_size: float = None):
59-
"""Split dataset into train, validation and test sets."""
73+
def train_val_test_split(
74+
self,
75+
dataset: Union[Dataset, DatasetDict],
76+
train_size: float = 0.8,
77+
val_size: float = 0.1,
78+
test_size: float = 0.1,
79+
shuffle: bool = True,
80+
seed: int = 42,
81+
split: str = "train"
82+
) -> Tuple[Dataset, Dataset, Dataset]:
83+
"""
84+
Split dataset into train, validation and test sets with progress indication.
85+
86+
Args:
87+
dataset (Dataset or DatasetDict): Dataset to split
88+
train_size (float): Proportion of training set
89+
val_size (float): Proportion of validation set
90+
test_size (float): Proportion of test set
91+
shuffle (bool): Whether to shuffle the dataset
92+
seed (int): Random seed
93+
split (str): Which split to use if dataset is a DatasetDict
94+
95+
Returns:
96+
Tuple[Dataset, Dataset, Dataset]: Train, validation and test datasets
97+
"""
6098
try:
61-
if not isinstance(dataset, Dataset):
62-
if isinstance(dataset, dict) and 'train' in dataset:
63-
dataset = dataset['train']
64-
else:
65-
raise ValueError(f"Expected Dataset object or dict with 'train' key, got {type(dataset)}")
66-
67-
if test_size is None:
68-
test_size = 1.0 - train_size - val_size
69-
70-
self.validate_split_params(train_size, val_size, test_size)
99+
# Get the actual dataset if we have a DatasetDict
100+
dataset = self._get_dataset_from_dict(dataset, split)
71101

72-
# If dataset is already split
73-
if isinstance(dataset, dict) and all(k in dataset for k in ['train', 'validation', 'test']):
74-
self.logger.log_info("Dataset already contains train/validation/test splits")
75-
return dataset['train'], dataset['validation'], dataset['test']
102+
# Validate split proportions
103+
total = train_size + val_size + test_size
104+
if not np.isclose(total, 1.0):
105+
raise ValueError(f"Split proportions must sum to 1, got {total}")
76106

77-
# Convert ratios to absolute sizes
107+
# Calculate split sizes
78108
total_size = len(dataset)
79-
if total_size == 0:
80-
raise ValueError("Dataset is empty")
81-
82-
train_end = int(total_size * train_size)
83-
val_end = train_end + int(total_size * val_size)
109+
train_samples = int(total_size * train_size)
110+
val_samples = int(total_size * val_size)
111+
test_samples = total_size - train_samples - val_samples
84112

85-
# Shuffle dataset with seed for reproducibility
86-
dataset = dataset.shuffle(seed=42)
113+
self.logger.log_info("Splitting dataset...")
87114

88-
# Split dataset
89-
train_dataset = dataset.select(range(train_end))
90-
val_dataset = dataset.select(range(train_end, val_end))
91-
test_dataset = dataset.select(range(val_end, total_size))
115+
# Create indices
116+
indices = np.arange(total_size)
117+
if shuffle:
118+
with tqdm(total=1, desc="Shuffling dataset", unit="operation") as pbar:
119+
rng = np.random.default_rng(seed)
120+
rng.shuffle(indices)
121+
pbar.update(1)
92122

93-
# Validate split sizes
94-
if len(train_dataset) == 0 or len(val_dataset) == 0 or len(test_dataset) == 0:
95-
raise ValueError("One or more splits are empty. Try adjusting split ratios.")
123+
# Split dataset using Hugging Face's built-in functionality
124+
with tqdm(total=2, desc="Creating splits", unit="split") as pbar:
125+
# First split: train vs rest
126+
train_val_split = dataset.train_test_split(
127+
train_size=train_size,
128+
seed=seed,
129+
shuffle=False # We already shuffled if needed
130+
)
131+
train_dataset = train_val_split["train"]
132+
rest_dataset = train_val_split["test"]
133+
pbar.update(1)
96134

135+
# Second split: val vs test from the rest
136+
val_ratio = val_size / (val_size + test_size)
137+
val_test_split = rest_dataset.train_test_split(
138+
train_size=val_ratio,
139+
seed=seed,
140+
shuffle=False
141+
)
142+
val_dataset = val_test_split["train"]
143+
test_dataset = val_test_split["test"]
144+
pbar.update(1)
145+
146+
# Log split sizes
97147
self.logger.log_info(f"Split sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
148+
98149
return train_dataset, val_dataset, test_dataset
99150

100151
except Exception as e:
101152
self.logger.log_error(f"Error splitting dataset: {str(e)}")
102153
raise
103-
154+
155+
def train_val_split(
156+
self,
157+
dataset: Union[Dataset, DatasetDict],
158+
train_size: float = 0.8,
159+
shuffle: bool = True,
160+
seed: int = 42,
161+
split: str = "train"
162+
) -> Tuple[Dataset, Dataset]:
163+
"""Split dataset into train and validation sets."""
164+
dataset = self._get_dataset_from_dict(dataset, split)
165+
return dataset.train_test_split(
166+
train_size=train_size,
167+
shuffle=shuffle,
168+
seed=seed
169+
).values()
170+
104171
def k_fold_split(self, dataset, n_splits: int = 5, shuffle: bool = True, seed: int = 42):
105172
"""Create k-fold cross validation splits."""
106173
try:

0 commit comments

Comments
 (0)