diff --git a/.gitignore b/.gitignore
index acc54b9..aac9309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ htmlcov/
 .coverage_*
 .pytest_cache/
 .vscode
+.idea
 *.log
 *.pyc
 examples/paddle_case/log
\ No newline at end of file
diff --git a/fastsafetensors/loader.py b/fastsafetensors/loader.py
index 60959fa..19c6335 100644
--- a/fastsafetensors/loader.py
+++ b/fastsafetensors/loader.py
@@ -46,6 +46,7 @@ def __init__(
         max_threads: int = 16,
         nogds: bool = False,
         set_numa: bool = True,
+        disable_cache: bool = True,
         debug_log: bool = False,
         framework="pytorch",
     ):
@@ -55,6 +56,7 @@
         self.debug_log = debug_log
         self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
         self.frames = OrderedDict[str, TensorFrame]()
+        self.disable_cache = disable_cache
         global loaded_nvidia
         if not loaded_nvidia:
             fstcpp.load_nvidia_functions()
@@ -154,6 +156,7 @@ def copy_files_to_device(
                 self.reader,
                 self.framework,
                 self.debug_log,
+                disable_cache=self.disable_cache,
             )
             factory.submit_io(use_buf_register, max_copy_block_size)
             factories[rank].append(factory)
diff --git a/fastsafetensors/tensor_factory.py b/fastsafetensors/tensor_factory.py
index 839d188..cebdb3e 100644
--- a/fastsafetensors/tensor_factory.py
+++ b/fastsafetensors/tensor_factory.py
@@ -24,6 +24,7 @@ def __init__(
         reader: Union[fstcpp.gds_file_reader, fstcpp.nogds_file_reader],
         framework: FrameworkOpBase,
         debug_log: bool = False,
+        disable_cache=True,
     ):
         self.framework = framework
         self.metadata = metadata
@@ -46,6 +47,7 @@ def __init__(
         self.factory_idx_bits = factory_idx_bits
         self.lidx = lidx
         self.next_tag = 1
+        self.disable_cache = disable_cache
 
     def submit_io(self, use_buf_register: bool, max_copy_block_size: int):
         if self.copier is not None:
@@ -160,7 +162,11 @@ def shuffle(self, pg: ProcessGroupBase, tensor_name: str, dim: int) -> TensorBase
                 f"shuffle: scatter, tensor_name={tensor_name}, shape={frame.shape}->{new_frame.shape}, self.rank={self.rank}, pg.rank()={pg.rank()}, rank_slices={rank_slices}, len(scatter_list)={len(scatter_list)}"
             )
         pg.scatter(dst, scatter_list=scatter_list, src=self.rank)
-        self.shuffled[tensor_name] = dst
+        if not self.disable_cache:
+            # Cache tensor for reuse within the same batch to improve performance.
+            # Note: This requires additional (GPU) memory to store the cached tensors.
+            # Enable this only if you have sufficient (GPU) memory and reuse is required.
+            self.shuffled[tensor_name] = dst
         return dst
 
     def shuffle_multi_cols(