 from datasets import Dataset as HFDataset
 from .dataset_preprocessor import DatasetPreprocessor
 
-class DataLoader(TorchDataLoader):
+class DataLoader:
     """
-    Custom DataLoader class for QuantLLM that inherits from torch.utils.data.DataLoader.
-    Provides additional functionality and easier integration with the QuantLLM package.
+    Custom DataLoader factory for QuantLLM, built on torch.utils.data.DataLoader.
1110 """
1211
13- def __init__ (
14- self ,
15- dataset : Dataset ,
16- batch_size : int = 4 ,
17- shuffle : bool = True ,
18- num_workers : int = 4 ,
19- pin_memory : bool = True ,
20- drop_last : bool = False ,
21- ** kwargs
22- ):
23- """
24- Initialize the QuantLLM DataLoader.
25-
26- Args:
27- dataset (Dataset): The dataset to load
28- batch_size (int): Number of samples per batch
29- shuffle (bool): Whether to shuffle the data
30- num_workers (int): Number of worker processes for data loading
31- pin_memory (bool): Whether to pin memory for faster data transfer to GPU
32- drop_last (bool): Whether to drop the last incomplete batch
33- **kwargs: Additional arguments to pass to the DataLoader
34- """
35- self .loader = TorchDataLoader (
36- dataset = dataset ,
37- batch_size = batch_size ,
38- shuffle = shuffle ,
39- num_workers = num_workers ,
40- pin_memory = pin_memory ,
41- drop_last = drop_last ,
42- ** kwargs
43- )
44- self .dataset = dataset
45- self .batch_size = batch_size
46-
4712 @staticmethod
4813 def validate_dataset (dataset , name : str ):
4914 """Validate dataset."""
@@ -74,24 +39,47 @@ def from_datasets(
             if batch_size <= 0:
                 raise ValueError(f"batch_size must be positive, got {batch_size}")
 
-            # Convert HuggingFace Dataset to PyTorch Dataset if needed
-            def convert_to_torch_dataset(hf_dataset):
-                if hf_dataset is None:
+            def prepare_dataset(dataset):
+                if dataset is None:
                     return None
-                if isinstance(hf_dataset, HFDataset):
-                    return hf_dataset.with_format("torch")
-                return hf_dataset
+
+                if isinstance(dataset, HFDataset):
+                    # Ensure all required features are present
+                    required_features = ['input_ids', 'attention_mask', 'labels']
+                    if not all(feature in dataset.features for feature in required_features):
+                        raise ValueError(f"Dataset must contain all required features: {required_features}")
+
+                    # Get feature dimensions
+                    sample_len = len(dataset[0]['input_ids'])
+                    total_samples = len(dataset)
+
+                    # Pre-allocate tensors
+                    input_ids = torch.zeros((total_samples, sample_len), dtype=torch.long)
+                    attention_mask = torch.zeros((total_samples, sample_len), dtype=torch.long)
+                    labels = torch.zeros((total_samples, sample_len), dtype=torch.long)
+
+                    # Fill tensors
+                    for i in range(total_samples):
+                        input_ids[i] = torch.tensor(dataset[i]['input_ids'])
+                        attention_mask[i] = torch.tensor(dataset[i]['attention_mask'])
+                        labels[i] = torch.tensor(dataset[i]['labels'])
+
+                    return TensorDataset(input_ids, attention_mask, labels)
+
+                return dataset
 
-            train_dataset = convert_to_torch_dataset(train_dataset)
-            val_dataset = convert_to_torch_dataset(val_dataset)
-            test_dataset = convert_to_torch_dataset(test_dataset)
+            train_dataset = prepare_dataset(train_dataset)
+            val_dataset = prepare_dataset(val_dataset)
+            test_dataset = prepare_dataset(test_dataset)
 
+            # Create DataLoaders with consistent batch sizes
             train_loader = TorchDataLoader(
                 train_dataset,
                 batch_size=batch_size,
                 shuffle=shuffle,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if train_dataset is not None else None
 
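Note: prepare_dataset now materializes the whole split eagerly as three torch.long (int64) tensors, so peak host memory grows as total_samples * sample_len * 8 bytes * 3. A quick back-of-the-envelope check with hypothetical numbers (100,000 samples at sequence length 2048, not taken from this diff):

    >>> 100_000 * 2048 * 8 * 3 / 1e9
    4.9152    # ~4.9 GB of host RAM, allocated before training starts
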
@@ -101,6 +89,7 @@ def convert_to_torch_dataset(hf_dataset):
                 shuffle=False,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if val_dataset is not None else None
 
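Note: with drop_last=True, a PyTorch DataLoader yields floor(len(dataset) / batch_size) batches, so the validation and test loaders silently skip up to batch_size - 1 trailing samples. A hypothetical illustration (the sizes are examples, not from this diff):

    >>> len(val_dataset), batch_size
    (1003, 8)
    >>> len(val_loader)    # floor(1003 / 8); the last 3 samples are never evaluated
    125
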
@@ -110,83 +99,12 @@ def convert_to_torch_dataset(hf_dataset):
                 shuffle=False,
                 num_workers=num_workers,
                 pin_memory=pin_memory and torch.cuda.is_available(),
+                drop_last=True,  # Drop last incomplete batch
                 **kwargs
             ) if test_dataset is not None else None
 
             return train_loader, val_loader, test_loader
 
         except Exception as e:
             print(f"Error creating data loaders: {str(e)}")
-            raise
-
-    @classmethod
-    def from_tensors(
-        cls,
-        input_ids,
-        attention_mask,
-        labels=None,
-        batch_size: int = 8,
-        **kwargs
-    ):
-        """Create DataLoader from tensor inputs."""
-        try:
-            if not isinstance(input_ids, torch.Tensor):
-                input_ids = torch.tensor(input_ids)
-            if not isinstance(attention_mask, torch.Tensor):
-                attention_mask = torch.tensor(attention_mask)
-
-            if labels is not None:
-                if not isinstance(labels, torch.Tensor):
-                    labels = torch.tensor(labels)
-                dataset = TensorDataset(input_ids, attention_mask, labels)
-            else:
-                dataset = TensorDataset(input_ids, attention_mask)
-
-            return TorchDataLoader(
-                dataset,
-                batch_size=batch_size,
-                **kwargs
-            )
-
-        except Exception as e:
-            raise RuntimeError(f"Error creating data loader from tensors: {str(e)}")
-
-    def get_batch(self) -> Dict[str, torch.Tensor]:
-        """
-        Get a single batch from the DataLoader.
-
-        Returns:
-            Dict[str, torch.Tensor]: Dictionary containing the batch data
-        """
-        try:
-            batch = next(iter(self.loader))
-            return batch
-        except StopIteration:
-            raise RuntimeError("No more batches available in the DataLoader")
-
-    def get_batch_size(self) -> int:
-        """
-        Get the current batch size of the DataLoader.
-
-        Returns:
-            int: Current batch size
-        """
-        return self.batch_size
-
-    def get_dataset_size(self) -> int:
-        """
-        Get the size of the underlying dataset.
-
-        Returns:
-            int: Size of the dataset
-        """
-        return len(self.dataset)
-
-    def get_num_batches(self) -> int:
-        """
-        Get the total number of batches in the DataLoader.
-
-        Returns:
-            int: Total number of batches
-        """
-        return len(self.loader)
+            raise
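
For reference, a minimal end-to-end sketch of the new code path. The dataset contents and the call style are illustrative assumptions (the diff does not show from_datasets' full signature or decorator), but the three required feature columns and the tuple-shaped batches follow from the code above:

    from datasets import Dataset as HFDataset
    # from ... import DataLoader  (import path depends on the package layout)

    # Hypothetical pre-tokenized split carrying the three required features.
    ds = HFDataset.from_dict({
        "input_ids":      [[101, 2023, 102]] * 16,
        "attention_mask": [[1, 1, 1]] * 16,
        "labels":         [[101, 2023, 102]] * 16,
    })

    # Assumes from_datasets is exposed as a classmethod, as from_tensors was.
    train_loader, val_loader, test_loader = DataLoader.from_datasets(
        train_dataset=ds,
        val_dataset=None,
        test_dataset=None,
        batch_size=4,
    )

    # A TensorDataset batch arrives as an (input_ids, attention_mask, labels) tuple.
    input_ids, attention_mask, labels = next(iter(train_loader))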