
Commit ec67a21

Fix the training issue and add CPU, GPU, and MPS support for training.
1 parent 33b7c61 commit ec67a21
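
The trainer and device-selection changes named in the commit message live in files not shown in this excerpt (only three of the seven changed files appear below). As context, here is a minimal sketch of how choosing between CPU, CUDA, and MPS backends is typically done in PyTorch; the helper name resolve_device is illustrative and not taken from this commit, and MPS detection requires PyTorch 1.12 or newer:

import torch

def resolve_device(preferred: str = "auto") -> torch.device:
    # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
    if preferred != "auto":
        return torch.device(preferred)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")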

7 files changed: +393 -243 lines

quantllm/config/model_config.py

Lines changed: 25 additions & 1 deletion
@@ -10,7 +10,7 @@ class ModelConfig:
     model_name: str
     model_type: str = "auto"
     revision: str = "main"
-    trust_remote_code: bool = False
+    trust_remote_code: bool = True

     # Model architecture
     hidden_size: Optional[int] = None
@@ -42,6 +42,12 @@ class ModelConfig:
     lora_config: Optional[Dict[str, Any]] = None
     use_lora: bool = False

+    # CPU optimization
+    cpu_offload: bool = False
+    gradient_checkpointing: bool = False
+    bf16: bool = False  # bfloat16 support for more efficient training
+    max_memory: Optional[dict] = None  # For device specific memory limits
+
     kwargs: Optional[Dict[str, Any]] = None
     device_map: Optional[Dict[str, str]] = 'auto'  # 'auto' or specific device mapping

@@ -58,6 +64,20 @@ def __post_init__(self):

         if self.kwargs is None:
             self.kwargs = {}
+
+        if self.load_in_4bit and self.load_in_8bit:
+            raise ValueError("Cannot use both 4-bit and 8-bit quantization simultaneously")
+
+        # Set reasonable defaults for memory management
+        if self.max_memory is None:
+            import torch
+            if torch.cuda.is_available():
+                # Leave some GPU memory free for system
+                total_memory = torch.cuda.get_device_properties(0).total_memory
+                self.max_memory = {0: f"{int(total_memory * 0.85 / 1024**3)}GiB"}
+            else:
+                # Default CPU memory limit
+                self.max_memory = {"cpu": "16GiB"}

     def to_dict(self) -> Dict[str, Any]:
         """Convert configuration to dictionary."""
@@ -88,6 +108,10 @@ def to_dict(self) -> Dict[str, Any]:
             "bnb_4bit_use_double_quant": self.bnb_4bit_use_double_quant,
             "lora_config": self.lora_config,
             "use_lora": self.use_lora,
+            "cpu_offload": self.cpu_offload,
+            "gradient_checkpointing": self.gradient_checkpointing,
+            "bf16": self.bf16,
+            "max_memory": self.max_memory,
             "kwargs": self.kwargs,
             "device_map": self.device_map
         }
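
The new ModelConfig fields make device-aware training setup configurable in one place: __post_init__ now rejects combining 4-bit and 8-bit quantization, and when max_memory is not given it caps the first GPU at roughly 85% of its total memory (in whole GiB) or falls back to a 16 GiB CPU limit. A minimal usage sketch, assuming the class is importable as quantllm.config.model_config.ModelConfig and the remaining constructor fields keep their defaults (the model name is only an example):

from quantllm.config.model_config import ModelConfig

config = ModelConfig(
    model_name="facebook/opt-350m",            # example model, not from this commit
    gradient_checkpointing=True,               # trade extra compute for lower memory
    bf16=True,                                 # use bfloat16 where supported
    max_memory={0: "10GiB", "cpu": "16GiB"},   # explicit caps skip the auto default
)
print(config.to_dict()["max_memory"])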

quantllm/data/dataloader.py

Lines changed: 37 additions & 119 deletions
@@ -4,46 +4,11 @@
 from datasets import Dataset as HFDataset
 from .dataset_preprocessor import DatasetPreprocessor

-class DataLoader(TorchDataLoader):
+class DataLoader:
     """
-    Custom DataLoader class for QuantLLM that inherits from torch.utils.data.DataLoader.
-    Provides additional functionality and easier integration with the QuantLLM package.
+    Custom DataLoader class for QuantLLM that wraps torch.utils.data.DataLoader.
     """

-    def __init__(
-        self,
-        dataset: Dataset,
-        batch_size: int = 4,
-        shuffle: bool = True,
-        num_workers: int = 4,
-        pin_memory: bool = True,
-        drop_last: bool = False,
-        **kwargs
-    ):
-        """
-        Initialize the QuantLLM DataLoader.
-
-        Args:
-            dataset (Dataset): The dataset to load
-            batch_size (int): Number of samples per batch
-            shuffle (bool): Whether to shuffle the data
-            num_workers (int): Number of worker processes for data loading
-            pin_memory (bool): Whether to pin memory for faster data transfer to GPU
-            drop_last (bool): Whether to drop the last incomplete batch
-            **kwargs: Additional arguments to pass to the DataLoader
-        """
-        self.loader = TorchDataLoader(
-            dataset=dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            num_workers=num_workers,
-            pin_memory=pin_memory,
-            drop_last=drop_last,
-            **kwargs
-        )
-        self.dataset = dataset
-        self.batch_size = batch_size
-
     @staticmethod
     def validate_dataset(dataset, name: str):
         """Validate dataset."""
@@ -74,24 +39,47 @@ def from_datasets(
             if batch_size <= 0:
                 raise ValueError(f"batch_size must be positive, got {batch_size}")

-            # Convert HuggingFace Dataset to PyTorch Dataset if needed
-            def convert_to_torch_dataset(hf_dataset):
-                if hf_dataset is None:
+            def prepare_dataset(dataset):
+                if dataset is None:
                     return None
-                if isinstance(hf_dataset, HFDataset):
-                    return hf_dataset.with_format("torch")
-                return hf_dataset
+
+                if isinstance(dataset, HFDataset):
+                    # Ensure all required features are present
+                    required_features = ['input_ids', 'attention_mask', 'labels']
+                    if not all(feature in dataset.features for feature in required_features):
+                        raise ValueError(f"Dataset must contain all required features: {required_features}")
+
+                    # Get feature dimensions
+                    sample_len = len(dataset[0]['input_ids'])
+                    total_samples = len(dataset)
+
+                    # Pre-allocate tensors
+                    input_ids = torch.zeros((total_samples, sample_len), dtype=torch.long)
+                    attention_mask = torch.zeros((total_samples, sample_len), dtype=torch.long)
+                    labels = torch.zeros((total_samples, sample_len), dtype=torch.long)
+
+                    # Fill tensors
+                    for i in range(total_samples):
+                        input_ids[i] = torch.tensor(dataset[i]['input_ids'])
+                        attention_mask[i] = torch.tensor(dataset[i]['attention_mask'])
+                        labels[i] = torch.tensor(dataset[i]['labels'])
+
+                    return TensorDataset(input_ids, attention_mask, labels)
+
+                return dataset

-            train_dataset = convert_to_torch_dataset(train_dataset)
-            val_dataset = convert_to_torch_dataset(val_dataset)
-            test_dataset = convert_to_torch_dataset(test_dataset)
+            train_dataset = prepare_dataset(train_dataset)
+            val_dataset = prepare_dataset(val_dataset)
+            test_dataset = prepare_dataset(test_dataset)

+            # Create DataLoaders with consistent batch sizes
             train_loader = TorchDataLoader(
                 train_dataset,
                 batch_size=batch_size,
                 shuffle=shuffle,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if train_dataset is not None else None

@@ -101,6 +89,7 @@ def convert_to_torch_dataset(hf_dataset):
                 shuffle=False,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if val_dataset is not None else None

@@ -110,83 +99,12 @@ def convert_to_torch_dataset(hf_dataset):
                 shuffle=False,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if test_dataset is not None else None

             return train_loader, val_loader, test_loader

         except Exception as e:
             print(f"Error creating data loaders: {str(e)}")
-            raise
-
-    @classmethod
-    def from_tensors(
-        cls,
-        input_ids,
-        attention_mask,
-        labels=None,
-        batch_size: int = 8,
-        **kwargs
-    ):
-        """Create DataLoader from tensor inputs."""
-        try:
-            if not isinstance(input_ids, torch.Tensor):
-                input_ids = torch.tensor(input_ids)
-            if not isinstance(attention_mask, torch.Tensor):
-                attention_mask = torch.tensor(attention_mask)
-
-            if labels is not None:
-                if not isinstance(labels, torch.Tensor):
-                    labels = torch.tensor(labels)
-                dataset = TensorDataset(input_ids, attention_mask, labels)
-            else:
-                dataset = TensorDataset(input_ids, attention_mask)
-
-            return TorchDataLoader(
-                dataset,
-                batch_size=batch_size,
-                **kwargs
-            )
-
-        except Exception as e:
-            raise RuntimeError(f"Error creating data loader from tensors: {str(e)}")
-
-    def get_batch(self) -> Dict[str, torch.Tensor]:
-        """
-        Get a single batch from the DataLoader.
-
-        Returns:
-            Dict[str, torch.Tensor]: Dictionary containing the batch data
-        """
-        try:
-            batch = next(iter(self.loader))
-            return batch
-        except StopIteration:
-            raise RuntimeError("No more batches available in the DataLoader")
-
-    def get_batch_size(self) -> int:
-        """
-        Get the current batch size of the DataLoader.
-
-        Returns:
-            int: Current batch size
-        """
-        return self.batch_size
-
-    def get_dataset_size(self) -> int:
-        """
-        Get the size of the underlying dataset.
-
-        Returns:
-            int: Size of the dataset
-        """
-        return len(self.dataset)
-
-    def get_num_batches(self) -> int:
-        """
-        Get the total number of batches in the DataLoader.
-
-        Returns:
-            int: Total number of batches
-        """
-        return len(self.loader)
+            raise
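
Because prepare_dataset now materializes HuggingFace datasets into a TensorDataset, each batch comes back as an (input_ids, attention_mask, labels) tuple rather than a dict, and drop_last=True silently discards any trailing partial batch. A usage sketch, assuming from_datasets accepts the tokenized splits and batch size as keyword arguments with these names (the full signature is not shown in this diff):

from quantllm.data.dataloader import DataLoader

# train_tok, val_tok, test_tok: tokenized HuggingFace datasets that already
# contain input_ids, attention_mask and labels columns.
train_loader, val_loader, test_loader = DataLoader.from_datasets(
    train_dataset=train_tok,
    val_dataset=val_tok,
    test_dataset=test_tok,
    batch_size=8,
)

for input_ids, attention_mask, labels in train_loader:
    # Each element is a LongTensor of shape (batch_size, sequence_length).
    pass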

quantllm/data/dataset_preprocessor.py

Lines changed: 48 additions & 21 deletions
@@ -11,8 +11,8 @@ def __init__(self, tokenizer, logger=None):
         # Set pad token if not set
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-            print("Added [PAD] token to tokenizer")
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+            print("Set pad token to eos token")

     def validate_datasets(self, datasets):
         """Validate input datasets."""
@@ -43,68 +43,95 @@ def tokenize_dataset(
             self.validate_datasets([train_dataset, val_dataset, test_dataset])

             def process_and_tokenize_batch(examples):
+                # Get texts and preprocess
                 texts = examples[text_column]
                 if not isinstance(texts, list):
                     texts = [texts]
-
-                # Preprocess texts
                 texts = [self.preprocess_text(text) for text in texts]

                 try:
+                    # Tokenize with padding and truncation
+                    # Use max_length + 1 to account for the shift we'll do later
                     tokenized = self.tokenizer(
                         texts,
-                        padding=True,
+                        padding="max_length",
                         truncation=True,
-                        max_length=max_length,
-                        return_tensors="pt"
+                        max_length=max_length + 1,  # Add 1 to account for shift
+                        return_tensors=None
                     )

+                    input_ids = tokenized["input_ids"]
+                    attention_mask = tokenized["attention_mask"]
+
+                    # Now shift to create inputs and labels
+                    # inputs will be [:-1] and labels will be [1:]
+                    labels = [ids[1:] for ids in input_ids]
+                    input_ids = [ids[:-1] for ids in input_ids]
+                    attention_mask = [mask[:-1] for mask in attention_mask]
+
+                    # Verify all sequences have the expected length
+                    expected_length = max_length
+                    if not all(len(seq) == expected_length for seq in input_ids):
+                        raise ValueError(f"Input sequence lengths don't match. Expected {expected_length}")
+                    if not all(len(seq) == expected_length for seq in attention_mask):
+                        raise ValueError(f"Attention mask lengths don't match. Expected {expected_length}")
+                    if not all(len(seq) == expected_length for seq in labels):
+                        raise ValueError(f"Label sequence lengths don't match. Expected {expected_length}")
+
                     result = {
-                        "input_ids": tokenized["input_ids"],
-                        "attention_mask": tokenized["attention_mask"]
+                        "input_ids": input_ids,
+                        "attention_mask": attention_mask,
+                        "labels": labels
                     }

-                    if label_column and label_column in examples:
-                        result["labels"] = examples[label_column]
-
-                    print(f"Tokenized batch of {len(texts)} texts")  # User feedback
+                    self.logger.log_info(f"Tokenized batch of {len(texts)} texts")
                     return result

                 except Exception as e:
-                    print(f"Error tokenizing batch: {str(e)}")  # User feedback
+                    self.logger.log_error(f"Error tokenizing batch: {str(e)}")
                     raise

             # Process datasets
             train_tokenized = train_dataset.map(
                 process_and_tokenize_batch,
                 batched=True,
                 batch_size=batch_size,
-                remove_columns=train_dataset.column_names
+                remove_columns=train_dataset.column_names,
+                desc="Tokenizing training set"
             )
-            print(f"Tokenized training dataset: {len(train_tokenized)} examples")  # User feedback
+            self.logger.log_info(f"Tokenized training dataset: {len(train_tokenized)} examples")

             val_tokenized = None
             if val_dataset is not None:
                 val_tokenized = val_dataset.map(
                     process_and_tokenize_batch,
                     batched=True,
                     batch_size=batch_size,
-                    remove_columns=val_dataset.column_names
+                    remove_columns=val_dataset.column_names,
+                    desc="Tokenizing validation set"
                 )
-                print(f"Tokenized validation dataset: {len(val_tokenized)} examples")  # User feedback
+                self.logger.log_info(f"Tokenized validation dataset: {len(val_tokenized)} examples")

             test_tokenized = None
             if test_dataset is not None:
                 test_tokenized = test_dataset.map(
                     process_and_tokenize_batch,
                     batched=True,
                     batch_size=batch_size,
-                    remove_columns=test_dataset.column_names
+                    remove_columns=test_dataset.column_names,
+                    desc="Tokenizing test set"
                 )
-                print(f"Tokenized test dataset: {len(test_tokenized)} examples")  # User feedback
+                self.logger.log_info(f"Tokenized test dataset: {len(test_tokenized)} examples")
+
+            # Set format to PyTorch tensors
+            train_tokenized.set_format("torch")
+            if val_tokenized:
+                val_tokenized.set_format("torch")
+            if test_tokenized:
+                test_tokenized.set_format("torch")

             return train_tokenized, val_tokenized, test_tokenized

         except Exception as e:
-            print(f"Error in tokenization: {str(e)}")  # User feedback
+            self.logger.log_error(f"Error in tokenization: {str(e)}")
             raise
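
The core change here is the shift-by-one labeling: each text is tokenized to max_length + 1 IDs, then input_ids keeps positions [:-1] and labels keeps positions [1:], so position i of the input is supervised to predict the token at position i + 1 and both sequences end up exactly max_length long. A standalone sketch of the same arithmetic with toy token IDs (no tokenizer required):

max_length = 4
ids = [101, 7, 8, 9, 102]    # pretend the tokenizer returned max_length + 1 IDs
mask = [1, 1, 1, 1, 1]

input_ids = ids[:-1]         # [101, 7, 8, 9]
labels = ids[1:]             # [7, 8, 9, 102]
attention_mask = mask[:-1]   # [1, 1, 1, 1]

assert len(input_ids) == len(labels) == len(attention_mask) == max_length

One design note: many Hugging Face causal-LM heads already shift labels internally when a labels tensor is passed to the model, so this explicit pre-shift mainly matters if the training loop computes its own loss; that code is outside this diff.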
