# SPDX-License-Identifier: Apache-2.0
33
44import math
5- import warnings
65from typing import Any , Dict , List , Optional , OrderedDict , Tuple , Union
76
87from . import cpp as fstcpp
98from .common import SafeTensorsMetadata , TensorFrame , get_device_numa_node
9+ from .copier .gds import new_gds_file_copier
1010from .file_buffer import FilesBufferOnDevice
11- from .frameworks import TensorBase , get_framework_op
12- from .st_types import DeviceType , DType
11+ from .frameworks import FrameworkOpBase , TensorBase , get_framework_op
12+ from .st_types import Device , DeviceType , DType
1313from .tensor_factory import LazyTensorFactory
1414
1515gl_set_numa = False
1616
1717loaded_nvidia = False
1818
1919
class BaseSafeTensorsFileLoader:
    r"""Base class for loading .safetensors files lazily.

    Subclasses supply a ``copier_constructor`` that builds a file-copier
    object per metadata entry; this base class only records state and
    (optionally) pins the process to the device's NUMA node.

    Args:
        pg (Optional[Any]): Process group-like object for distributed loading.
            Use None for single-device use-cases.
        device (Device): Target device where tensors will be loaded.
        copier_constructor: Callable that constructs a file copier; invoked
            later with (meta, device, framework, debug_log).
        set_numa (bool): Whether to set NUMA node affinity for the device.
        disable_cache (bool): Whether to disable caching of loaded tensors.
        debug_log (bool): Enable detailed debug logging.
        framework (str): Deep learning framework to use ("pytorch" or "paddle").
    """

    def __init__(
        self,
        pg: Optional[Any],
        device: Device,
        copier_constructor,
        set_numa: bool = True,
        disable_cache: bool = True,
        debug_log: bool = False,
        framework="pytorch",
    ):
        self.framework = get_framework_op(framework)
        self.pg = self.framework.get_process_group(pg)
        self.device = device
        self.debug_log = debug_log
        # filename -> (parsed safetensors metadata, owning rank)
        self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
        # tensor name -> frame, in registration order
        self.frames = OrderedDict[str, TensorFrame]()
        self.disable_cache = disable_cache
        self.init_numa(set_numa)
        self.copier_constructor = copier_constructor

    def init_numa(self, set_numa: bool = True):
        """Bind to the device's NUMA node once per process (no-op thereafter)."""
        global gl_set_numa
        if not gl_set_numa and set_numa:
            node = get_device_numa_node(self.device.index)
            if node is not None:
                fstcpp.set_numa_node(node)
            # NOTE(review): flag is set even when no node was found, so the
            # lookup is attempted only once per process — confirm intended.
            gl_set_numa = True
def reset(self):
    # Drop all parsed metadata and tensor frames so the loader can be reused.
    self.frames, self.meta = {}, {}
9565
def close(self):
    """Release loader state.

    Idempotent: safe to call more than once. The previous implementation
    used ``del self.copier_constructor``, so a second ``close()`` (or any
    later attribute access) raised ``AttributeError``; assigning ``None``
    drops the reference while keeping the attribute present.
    """
    self.reset()
    self.copier_constructor = None
9969
def get_keys(self) -> List[str]:
    """Return the names of all registered tensors, in insertion order."""
    return [name for name in self.frames]
@@ -145,8 +115,10 @@ def copy_files_to_device(
145115
146116 factory_idx_bits = math .ceil (math .log2 (len (self .meta ) + 1 ))
147117 lidx = 1
148-
149118 for _ , (meta , rank ) in sorted (self .meta .items (), key = lambda x : x [0 ]):
119+ copier = self .copier_constructor (
120+ meta , self .device , self .framework , self .debug_log
121+ )
150122 self_rank = self .pg .rank () == rank
151123 factory = LazyTensorFactory (
152124 meta ,
@@ -155,7 +127,7 @@ def copy_files_to_device(
155127 self_rank ,
156128 factory_idx_bits ,
157129 lidx ,
158- self . reader ,
130+ copier ,
159131 self .framework ,
160132 self .debug_log ,
161133 disable_cache = self .disable_cache ,
@@ -166,12 +138,63 @@ def copy_files_to_device(
166138 need_wait .append (factory )
167139 lidx += 1
168140 for factory in need_wait :
169- factory .wait_io (
170- dtype = dtype , noalign = isinstance (self .reader , fstcpp .nogds_file_reader )
171- )
141+ factory .wait_io (dtype = dtype , noalign = False )
172142 return FilesBufferOnDevice (factories , pg = self .pg , framework = self .framework )
173143
174144
class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):
    r"""Load .safetensors files lazily.

    Args:
        pg (Optional[Any]): process group-like object for distributed loading.
            None for single GPU use-cases.
        device (str): target device (e.g. "cpu", "cuda:0").
        bbuf_size_kb (int): bounce buffer size for file copies.
        max_threads (int): maximum number of threads for memory copies.
        nogds (bool): if True, turn off GDS and fall back to pread with a
            bounce buffer.
        set_numa (bool): whether to set NUMA node affinity for the device.
        disable_cache (bool): whether to disable caching of loaded tensors.
        debug_log (bool): enable debug logs.
        framework (str): deep learning framework to use ("pytorch" or "paddle").

    Examples:
        >> from fastsafetensors import SafeTensorsFileLoader
        >> src_files = download(target_dir, "gpt2")
        >> loader = SafeTensorsFileLoader(None, "cpu", nogds=True, debug_log=True)
        >> loader.add_filenames({0: src_files})
        >> bufs = loader.copy_files_to_device()
        >> print(bufs.get_tensor(loader.get_keys()[0]))
        >> loader.close()
    """

    def __init__(
        self,
        pg: Optional[Any],
        device: str = "cpu",
        bbuf_size_kb: int = 16 * 1024,
        max_threads: int = 16,
        nogds: bool = False,
        set_numa: bool = True,
        disable_cache: bool = True,
        debug_log: bool = False,
        framework="pytorch",
    ):
        # Resolve framework/pg/device first: the concrete Device is needed
        # both for GDS setup and for the base-class constructor below.
        self.framework = get_framework_op(framework)
        self.pg = self.framework.get_process_group(pg)
        self.device = self.framework.get_device(device, self.pg)

        fstcpp.set_debug_log(debug_log)
        global loaded_nvidia
        if not loaded_nvidia:
            fstcpp.load_nvidia_functions()
            if not nogds:
                # init_gds() can take 10s+, so it is skipped entirely when
                # GDS is not used.
                if fstcpp.init_gds() != 0:
                    raise Exception("[FAIL] init_gds()")
            loaded_nvidia = True

        copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
        super().__init__(
            pg, self.device, copier, set_numa, disable_cache, debug_log, framework
        )
197+
175198class fastsafe_open :
176199 """
177200 Opens a safetensors lazily and returns tensors as asked
0 commit comments