Commit 62efeec

PR changes
Signed-off-by: Abhishek <[email protected]>
1 parent ba0e543 commit 62efeec

3 files changed: 21 additions & 17 deletions

tuning/data/data_config.py

Lines changed: 3 additions & 3 deletions
@@ -21,6 +21,8 @@
 # Local
 from tuning.utils.utils import load_yaml_or_json
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class DataHandlerConfig:
@@ -82,9 +84,7 @@ def _validate_dataset_config(dataset_config) -> DataSetConfig:
         assert isinstance(p, str), f"path {p} should be of the type string"
         if not os.path.isabs(p):
             _p = os.path.abspath(p)
-            logging.warning(
-                " Provided path %s is not absolute changing it to %s", p, _p
-            )
+            logger.warning(" Provided path %s is not absolute changing it to %s", p, _p)
             p = _p
         c.data_paths.append(p)
     if "builder" in kwargs and kwargs["builder"] is not None:

tuning/data/data_processors.py

Lines changed: 14 additions & 12 deletions
@@ -29,6 +29,8 @@
 from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS
 from tuning.utils.utils import get_loader_for_filepath, validate_mergeable_datasets
 
+logger = logging.getLogger(__name__)
+
 
 class DataPreProcessor:
 
@@ -54,11 +56,11 @@ def register_data_handler(self, name: str, func: Callable):
         if not isinstance(name, str) or not callable(func):
             raise ValueError("Handlers should be of type Dict, str to callable")
         if name in self.registered_handlers:
-            logging.warning(
+            logger.warning(
                 "Handler name '%s' already exists and will be overwritten", name
             )
         self.registered_handlers[name] = func
-        logging.info("Registered new handler %s", name)
+        logger.info("Registered new handler %s", name)
 
     def register_data_handlers(self, handlers: Dict[str, Callable]):
         if handlers is None:
@@ -175,7 +177,7 @@ def _load_dataset(data_path=None, builder=None, data_files=None, data_dir=None):
             return all_datasets[0]
 
         raw_datasets = datasets.concatenate_datasets(all_datasets)
-        logging.info(
+        logger.info(
             "Datasets concatenated from %s .Concatenated dataset columns: %s",
             datasetconfig.name,
             list(raw_datasets.features.keys()),
@@ -207,25 +209,25 @@ def _process_dataset_configs(
             if sum(p for p in sampling_probabilities) != 1:
                 raise ValueError("Sampling probabilities don't sum to 1")
             sample_datasets = True
-            logging.info(
+            logger.info(
                 "Sampling ratios are specified; given datasets will be interleaved."
             )
         else:
-            logging.info(
+            logger.info(
                 "Sampling is not specified; if multiple datasets are provided,"
                 " the given datasets will be concatenated."
             )
             sample_datasets = False
 
-        logging.info("Starting DataPreProcessor...")
+        logger.info("Starting DataPreProcessor...")
         # Now Iterate over the multiple datasets provided to us to process
         for d in dataset_configs:
-            logging.info("Loading %s", d.name)
+            logger.info("Loading %s", d.name)
 
             # In future the streaming etc go as kwargs of this function
             raw_dataset = self.load_dataset(d, splitName)
 
-            logging.info("Loaded raw dataset : %s", str(raw_dataset))
+            logger.info("Loaded raw dataset : %s", str(raw_dataset))
 
             raw_datasets = DatasetDict()
 
@@ -266,7 +268,7 @@ def _process_dataset_configs(
 
                 kwargs["fn_kwargs"] = dict(kwargs["fn_kwargs"], **extra_kwargs)
 
-                logging.info("Applying Handler: %s Args: %s", data_handler, kwargs)
+                logger.info("Applying Handler: %s Args: %s", data_handler, kwargs)
 
                 raw_datasets = raw_datasets.map(handler, **kwargs)
 
@@ -285,7 +287,7 @@ def _process_dataset_configs(
         if sample_datasets:
             strategy = self.processor_config.sampling_stopping_strategy
             seed = self.processor_config.sampling_seed
-            logging.info(
+            logger.info(
                 "Interleaving datasets: strategy[%s] seed[%d] probabilities[%s]",
                 strategy,
                 seed,
@@ -316,7 +318,7 @@ def process_dataset_configs(
 
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             if torch.distributed.get_rank() == 0:
-                logging.info("Processing data on rank 0...")
+                logger.info("Processing data on rank 0...")
                 train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)
             else:
                 train_dataset = None
@@ -329,7 +331,7 @@ def process_dataset_configs(
             torch.distributed.broadcast_object_list(to_share, src=0)
             train_dataset = to_share[0]
         else:
-            logging.info("Processing data...")
+            logger.info("Processing data...")
            train_dataset = self._process_dataset_configs(dataset_configs, **kwargs)
 
         return train_dataset
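
Aside from the logging change, the hunks in process_dataset_configs show the repo's rank-0 pattern: heavy preprocessing runs once and the result is broadcast to the other ranks. A rough sketch of that pattern under simplifying assumptions (the wrapper name and single-process fallback are illustrative, not code from this PR):

import torch.distributed as dist


def run_once_and_broadcast(process_fn, *args, **kwargs):
    # If torch.distributed is active, only rank 0 does the expensive work.
    if dist.is_available() and dist.is_initialized():
        result = process_fn(*args, **kwargs) if dist.get_rank() == 0 else None
        to_share = [result]
        # broadcast_object_list pickles the object on src and unpickles it on
        # every other rank, so all ranks receive the same processed dataset.
        dist.broadcast_object_list(to_share, src=0)
        return to_share[0]
    # Single-process fallback: just process locally.
    return process_fn(*args, **kwargs)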

tuning/data/setup_dataprocessor.py

Lines changed: 4 additions & 2 deletions
@@ -33,6 +33,8 @@
 from tuning.data.data_preprocessing_utils import get_data_collator
 from tuning.data.data_processors import get_datapreprocessor
 
+logger = logging.getLogger(__name__)
+
 # In future we may make the fields configurable
 DEFAULT_INPUT_COLUMN = "input"
 DEFAULT_OUTPUT_COLUMN = "output"
@@ -320,9 +322,9 @@ def process_dataargs(
     """
 
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
-    logging.info("Max sequence length is %s", max_seq_length)
+    logger.info("Max sequence length is %s", max_seq_length)
     if train_args.max_seq_length > tokenizer.model_max_length:
-        logging.warning(
+        logger.warning(
             "max_seq_length %s exceeds tokenizer.model_max_length \
             %s, using tokenizer.model_max_length %s",
             train_args.max_seq_length,