diff --git a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py index 26430330..4da7dbee 100644 --- a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py +++ b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py @@ -63,7 +63,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: Loaded HF dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path) @@ -342,7 +343,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: The loaded dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path) diff --git a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py index 9ba9e829..de8c80b0 100644 --- a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py +++ b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py @@ -54,8 +54,8 @@ def __init__( if self.negatives_cross_device: if not dist.is_initialized(): raise ValueError('Distributed training has not been initialized for representation all gather.') - self.process_rank = dist.get_rank() - self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() if dist.is_initialized() else 0 + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.sub_batch_size = sub_batch_size self.kd_loss_type = kd_loss_type diff --git a/FlagEmbedding/abc/finetune/reranker/AbsDataset.py b/FlagEmbedding/abc/finetune/reranker/AbsDataset.py index 22389770..73830bbb 100644 --- a/FlagEmbedding/abc/finetune/reranker/AbsDataset.py +++ b/FlagEmbedding/abc/finetune/reranker/AbsDataset.py @@ -64,7 +64,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: Loaded HF dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)