Skip to content

Commit 40a5d58

Browse files
committed
add custom copier interface
Signed-off-by: yuanyuxing.yyx <[email protected]>
1 parent d04db8d commit 40a5d58

File tree

5 files changed

+171
-73
lines changed

5 files changed

+171
-73
lines changed

fastsafetensors/copier/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from typing import Dict, Protocol
2+
3+
from .. import cpp as fstcpp
4+
from ..frameworks import TensorBase
5+
from ..st_types import DType
6+
7+
8+
class CopierInterface(Protocol):
    """Structural type shared by the file-copier implementations.

    A copier stages one safetensors file's bytes into a device buffer
    (``submit_io``) and later materializes tensors out of that buffer
    (``wait_io``).  GdsFileCopier and NoGdsFileCopier both satisfy this
    protocol.
    """

    def submit_io(
        self, use_buf_register: bool, max_copy_block_size: int
    ) -> fstcpp.gds_device_buffer:
        """Start the file-to-device copy; return the buffer receiving the data."""
        pass

    def wait_io(
        self,
        gbuf: fstcpp.gds_device_buffer,
        dtype: DType = DType.AUTO,
        noalign: bool = False,
    ) -> Dict[str, TensorBase]:
        """Block until the copy into *gbuf* completes; return tensors keyed by name.

        Implementations delegate to ``metadata.get_tensors(...)``, optionally
        converting to *dtype*.
        """
        pass
21+
22+
23+
class DummyDeviceBuffer(fstcpp.gds_device_buffer):
    """Sentinel buffer holding no real device memory.

    LazyTensorFactory.free_dev_ptrs skips freeing a buffer when it is an
    instance of this class, so copiers can hand back a placeholder without
    allocating.
    """

    def __init__(self):
        # (0, 0, False): presumably a null address, zero length, and a flag
        # marking it as not device-allocated -- TODO confirm against the
        # fstcpp.gds_device_buffer binding.
        super().__init__(0, 0, False)

fastsafetensors/copier/gds.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import warnings
45
from typing import Dict, Optional
56

67
from .. import cpp as fstcpp
78
from ..common import SafeTensorsMetadata
89
from ..frameworks import FrameworkOpBase, TensorBase
910
from ..st_types import Device, DeviceType, DType
11+
from .base import CopierInterface
12+
from .nogds import new_nogds_file_copier
1013

1114

12-
class GdsFileCopier:
15+
class GdsFileCopier(CopierInterface):
1316
def __init__(
1417
self,
1518
metadata: SafeTensorsMetadata,
@@ -139,3 +142,35 @@ def wait_io(
139142
return self.metadata.get_tensors(
140143
gbuf, self.device, self.aligned_offset, dtype=dtype
141144
)
145+
146+
147+
def new_gds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
    nogds: bool = False,
):
    """Build a copier constructor that uses GDS, falling back to pread if unavailable.

    Args:
        device (Device): target device for loaded tensors.
        bbuf_size_kb (int): bounce buffer size for file copies (used only on
            the no-GDS fallback path).
        max_threads (int): maximum number of reader threads.
        nogds (bool): if True, skip GDS and use the pread-based copier.

    Returns:
        A callable ``(metadata, device, framework, debug_log) -> CopierInterface``.

    Raises:
        Exception: if a non-CPU device is requested but libcudart.so is missing.
    """
    device_is_not_cpu = device.type != DeviceType.CPU
    if device_is_not_cpu and not fstcpp.is_cuda_found():
        raise Exception("[FAIL] libcudart.so does not exist")
    if not fstcpp.is_cufile_found() and not nogds:
        # stacklevel=2 attributes the warning to the caller rather than to
        # this helper, which is where the misconfiguration actually lives.
        warnings.warn(
            "libcufile.so does not exist but nogds is False. use nogds=True",
            UserWarning,
            stacklevel=2,
        )
        nogds = True

    if nogds:
        # Delegate entirely; no gds_file_reader is created on this path.
        return new_nogds_file_copier(device, bbuf_size_kb, max_threads)

    # One shared reader; every copier built below reuses its thread pool.
    reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu)

    def construct_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        # Per-file copier closing over the shared reader.
        return GdsFileCopier(metadata, device, reader, framework, debug_log)

    return construct_copier

fastsafetensors/copier/nogds.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from .. import cpp as fstcpp
88
from ..common import SafeTensorsMetadata
99
from ..frameworks import FrameworkOpBase, TensorBase
10-
from ..st_types import Device, DType
10+
from ..st_types import Device, DeviceType, DType
11+
from .base import CopierInterface
1112

1213

13-
class NoGdsFileCopier:
14+
class NoGdsFileCopier(CopierInterface):
1415
def __init__(
1516
self,
1617
metadata: SafeTensorsMetadata,
@@ -66,3 +67,26 @@ def wait_io(
6667
return self.metadata.get_tensors(
6768
gbuf, self.device, self.metadata.header_length, dtype=dtype
6869
)
70+
71+
72+
def new_nogds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
):
    """Build a copier constructor that reads files with pread + bounce buffer (no GDS).

    Args:
        device (Device): target device for loaded tensors.
        bbuf_size_kb (int): bounce buffer size for file copies.
        max_threads (int): maximum number of reader threads.

    Returns:
        A callable ``(metadata, device, framework, debug_log) -> CopierInterface``.

    Raises:
        Exception: if a non-CPU device is requested but libcudart.so is missing.
    """
    device_is_not_cpu = device.type != DeviceType.CPU
    if device_is_not_cpu and not fstcpp.is_cuda_found():
        raise Exception("[FAIL] libcudart.so does not exist")
    # One shared reader; every copier built below reuses its bounce buffer
    # and thread pool.
    reader = fstcpp.nogds_file_reader(
        False, bbuf_size_kb, max_threads, device_is_not_cpu
    )

    def construct_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        # Per-file copier closing over the shared reader.
        return NoGdsFileCopier(metadata, device, reader, framework, debug_log)

    return construct_copier

fastsafetensors/loader.py

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,100 +2,70 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import math
5-
import warnings
65
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
76

87
from . import cpp as fstcpp
98
from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node
9+
from .copier.gds import new_gds_file_copier
1010
from .file_buffer import FilesBufferOnDevice
11-
from .frameworks import TensorBase, get_framework_op
12-
from .st_types import DeviceType, DType
11+
from .frameworks import FrameworkOpBase, TensorBase, get_framework_op
12+
from .st_types import Device, DeviceType, DType
1313
from .tensor_factory import LazyTensorFactory
1414

1515
gl_set_numa = False
1616

1717
loaded_nvidia = False
1818

1919

20-
class SafeTensorsFileLoader:
21-
r"""Load .safetensors files lazily.
20+
class BaseSafeTensorsFileLoader:
21+
r"""Base class for loading .safetensors files lazily.
2222
2323
Args:
24-
devcie (str): target device.
25-
pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
26-
bbuf_size_kb (int): bounce buffer size for file copies.
27-
max_threads (int): maximum number of threads for memory copies.
28-
nogds (bool): if True, trun off GDS and fallback to pread with bounce buffer.
29-
debug_log (bool): enable debug logs.
30-
31-
Examples:
32-
>> from fastsafetensors import SafeTensorsFileLoader
33-
>> src_files = download(target_dir, "gpt2")
34-
>> loader = SafeTensorsFileLoader(Device("cpu"), nogds=True, debug_log=True)
35-
>> loader.add_filenames({0: src_files})
36-
>> bufs = loader.copy_files_to_device()
37-
>> print(bufs.get_tensor(loader.get_keys()[0]))
38-
>> loader.close()
24+
pg (Optional[Any]): Process group-like objects for distributed loading.
25+
Use None for single device use-cases.
26+
device (Device): Target device where tensors will be loaded (CPU, CUDA, etc.).
27+
copier_constructor: Constructor function for creating file copier objects.
28+
set_numa (bool): Whether to set NUMA node affinity for optimized memory access.
29+
disable_cache (bool): Whether to disable caching of loaded tensors.
30+
debug_log (bool): Enable detailed debug logging.
31+
framework (str): Deep learning framework to use ("pytorch" or "paddle").
3932
"""
4033

4134
def __init__(
    self,
    pg: Optional[Any],
    device: Device,
    copier_constructor,
    set_numa: bool = True,
    disable_cache: bool = True,
    debug_log: bool = False,
    framework="pytorch",
):
    """Initialize shared loader state; see the class docstring for arguments."""
    self.framework = get_framework_op(framework)
    self.pg = self.framework.get_process_group(pg)
    self.device = device
    self.debug_log = debug_log
    # filename -> (parsed safetensors metadata, rank that owns the file)
    self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
    # tensor key -> frame, in registration order
    self.frames = OrderedDict[str, TensorFrame]()
    self.disable_cache = disable_cache
    self.init_numa(set_numa)
    # Callable (metadata, device, framework, debug_log) -> CopierInterface,
    # e.g. the closure returned by new_gds_file_copier / new_nogds_file_copier.
    self.copier_constructor = copier_constructor
53+
54+
def init_numa(self, set_numa: bool = True):
    """Bind to the NUMA node of self.device, once per process.

    A no-op when *set_numa* is False or a previous loader already did it
    (tracked by the module-level ``gl_set_numa`` flag).
    """
    global gl_set_numa
    if gl_set_numa or not set_numa:
        return
    numa_node = get_device_numa_node(self.device.index)
    if numa_node is not None:
        fstcpp.set_numa_node(numa_node)
    # Mark done even when no node was found, so we never probe twice.
    gl_set_numa = True
74-
fstcpp.set_debug_log(debug_log)
75-
device_is_not_cpu = self.device.type != DeviceType.CPU
76-
if device_is_not_cpu and not fstcpp.is_cuda_found():
77-
raise Exception("[FAIL] libcudart.so does not exist")
78-
if not fstcpp.is_cufile_found() and not nogds:
79-
warnings.warn(
80-
"libcufile.so does not exist but nogds is False. use nogds=True",
81-
UserWarning,
82-
)
83-
nogds = True
84-
self.reader: Union[fstcpp.nogds_file_reader, fstcpp.gds_file_reader]
85-
if nogds:
86-
self.reader = fstcpp.nogds_file_reader(
87-
False, bbuf_size_kb, max_threads, device_is_not_cpu
88-
)
89-
else:
90-
self.reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu)
9161

9262
def reset(self):
9363
self.frames = {}
9464
self.meta = {}
9565

9666
def close(self):
    """Release all loader state.

    Dropping ``copier_constructor`` also releases the file reader it
    presumably holds via closure (see new_gds_file_copier /
    new_nogds_file_copier) -- once no other reference remains.
    """
    self.reset()
    del self.copier_constructor
9969

10070
def get_keys(self) -> List[str]:
10171
return list(self.frames.keys())
@@ -145,8 +115,10 @@ def copy_files_to_device(
145115

146116
factory_idx_bits = math.ceil(math.log2(len(self.meta) + 1))
147117
lidx = 1
148-
149118
for _, (meta, rank) in sorted(self.meta.items(), key=lambda x: x[0]):
119+
copier = self.copier_constructor(
120+
meta, self.device, self.framework, self.debug_log
121+
)
150122
self_rank = self.pg.rank() == rank
151123
factory = LazyTensorFactory(
152124
meta,
@@ -155,7 +127,7 @@ def copy_files_to_device(
155127
self_rank,
156128
factory_idx_bits,
157129
lidx,
158-
self.reader,
130+
copier,
159131
self.framework,
160132
self.debug_log,
161133
disable_cache=self.disable_cache,
@@ -166,12 +138,63 @@ def copy_files_to_device(
166138
need_wait.append(factory)
167139
lidx += 1
168140
for factory in need_wait:
169-
factory.wait_io(
170-
dtype=dtype, noalign=isinstance(self.reader, fstcpp.nogds_file_reader)
171-
)
141+
factory.wait_io(dtype=dtype, noalign=False)
172142
return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework)
173143

174144

145+
class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):
    r"""Load .safetensors files lazily.

    Args:
        pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
        device (str): target device.
        bbuf_size_kb (int): bounce buffer size for file copies.
        max_threads (int): maximum number of threads for memory copies.
        nogds (bool): if True, turn off GDS and fall back to pread with bounce buffer.
        set_numa (bool): set NUMA node affinity for optimized memory access.
        disable_cache (bool): disable caching of loaded tensors.
        debug_log (bool): enable debug logs.
        framework (str): deep learning framework to use ("pytorch" or "paddle").

    Examples:
        >> from fastsafetensors import SafeTensorsFileLoader
        >> src_files = download(target_dir, "gpt2")
        >> loader = SafeTensorsFileLoader(None, "cpu", nogds=True, debug_log=True)
        >> loader.add_filenames({0: src_files})
        >> bufs = loader.copy_files_to_device()
        >> print(bufs.get_tensor(loader.get_keys()[0]))
        >> loader.close()
    """

    def __init__(
        self,
        pg: Optional[Any],
        device: str = "cpu",
        bbuf_size_kb: int = 16 * 1024,
        max_threads: int = 16,
        nogds: bool = False,
        set_numa: bool = True,
        disable_cache: bool = True,
        debug_log: bool = False,
        framework="pytorch",
    ):
        # NOTE(review): framework and pg are resolved here (to turn the device
        # string into a Device), then resolved again inside
        # BaseSafeTensorsFileLoader.__init__ -- harmless only if
        # get_framework_op/get_process_group are idempotent; worth confirming.
        self.framework = get_framework_op(framework)
        self.pg = self.framework.get_process_group(pg)
        self.device = self.framework.get_device(device, self.pg)

        fstcpp.set_debug_log(debug_log)
        # Load NVIDIA libraries once per process; init_gds() is skipped for
        # nogds because it can take 10s+ and is unneeded on that path.
        global loaded_nvidia
        if not loaded_nvidia:
            fstcpp.load_nvidia_functions()
            if not nogds:
                # no need to init gds and consume 10s+ in none-gds case
                if fstcpp.init_gds() != 0:
                    raise Exception(f"[FAIL] init_gds()")
            loaded_nvidia = True

        copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
        super().__init__(
            pg, self.device, copier, set_numa, disable_cache, debug_log, framework
        )
196+
197+
175198
class fastsafe_open:
176199
"""
177200
Opens a safetensors lazily and returns tensors as asked

fastsafetensors/tensor_factory.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from collections import OrderedDict
54
from typing import Dict, List, Optional, Tuple, Union
65

76
from . import cpp as fstcpp
87
from .common import SafeTensorsMetadata
9-
from .copier.gds import GdsFileCopier
10-
from .copier.nogds import NoGdsFileCopier
8+
from .copier.base import CopierInterface, DummyDeviceBuffer
119
from .frameworks import FrameworkOpBase, ProcessGroupBase, TensorBase
1210
from .st_types import Device, DType
1311

@@ -21,24 +19,17 @@ def __init__(
2119
local_rank: bool,
2220
factory_idx_bits: int,
2321
lidx: int,
24-
reader: Union[fstcpp.gds_file_reader, fstcpp.nogds_file_reader],
22+
copier: CopierInterface,
2523
framework: FrameworkOpBase,
2624
debug_log: bool = False,
2725
disable_cache=True,
2826
):
2927
self.framework = framework
3028
self.metadata = metadata
3129
self.device = device
32-
self.copier: Optional[Union[NoGdsFileCopier, GdsFileCopier]] = None
30+
self.copier: Optional[CopierInterface] = None
3331
if local_rank:
34-
if isinstance(reader, fstcpp.nogds_file_reader):
35-
self.copier = NoGdsFileCopier(
36-
metadata, device, reader, framework, debug_log
37-
)
38-
else:
39-
self.copier = GdsFileCopier(
40-
metadata, device, reader, framework, debug_log
41-
)
32+
self.copier = copier
4233
self.tensors: Dict[str, TensorBase] = {}
4334
self.shuffled: Dict[str, TensorBase] = {}
4435
self.gbuf: Optional[fstcpp.gds_device_buffer] = None
@@ -224,7 +215,7 @@ def shuffle_multi_cols(
224215

225216
def free_dev_ptrs(self):
226217
self.tensors = {}
227-
if self.gbuf is not None:
218+
if self.gbuf is not None and not isinstance(self.gbuf, DummyDeviceBuffer):
228219
self.framework.free_tensor_memory(self.gbuf, self.device)
229220
if self.debug_log:
230221
print(

0 commit comments

Comments
 (0)