Skip to content

Commit e99b606

Browse files
authored
add custom copier interface (#32)
Signed-off-by: yuanyuxing.yyx <[email protected]>
1 parent d04db8d commit e99b606

File tree

6 files changed

+221
-73
lines changed

6 files changed

+221
-73
lines changed

fastsafetensors/copier/base.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Dict
5+
6+
from .. import cpp as fstcpp
7+
from ..frameworks import TensorBase
8+
from ..st_types import DType
9+
10+
11+
class CopierInterface(ABC):
    """Abstract interface for file-to-device copy strategies.

    Concrete copiers (e.g. the GDS-based and pread-based implementations
    in this package) schedule reads with submit_io() and materialize the
    resulting tensors with wait_io().
    """

    @abstractmethod
    def submit_io(
        self, use_buf_register: bool, max_copy_block_size: int
    ) -> fstcpp.gds_device_buffer:
        """Start copying file contents toward the target device.

        Args:
            use_buf_register: whether to register the destination/bounce
                buffer (exact semantics defined by the concrete copier).
            max_copy_block_size: upper bound on a single copy block
                (presumably in bytes -- confirm against implementations).

        Returns:
            The device buffer that will hold the file contents.
        """
        pass

    @abstractmethod
    def wait_io(
        self,
        gbuf: fstcpp.gds_device_buffer,
        dtype: DType = DType.AUTO,
        noalign: bool = False,
    ) -> Dict[str, TensorBase]:
        """Block until the I/O submitted for ``gbuf`` completes.

        Args:
            gbuf: buffer previously returned by submit_io().
            dtype: target dtype; DType.AUTO keeps the stored dtype.
            noalign: NOTE(review): alignment-handling toggle; semantics
                depend on the concrete copier -- confirm there.

        Returns:
            Mapping from tensor name to materialized tensor.
        """
        pass
26+
27+
28+
class DummyDeviceBuffer(fstcpp.gds_device_buffer):
    """Placeholder device buffer for copiers that have nothing to allocate.

    Constructs the underlying C++ buffer with (0, 0, False) -- an empty,
    unregistered buffer; the exact parameter meanings are defined by the
    fstcpp binding.
    """

    def __init__(self):
        super().__init__(0, 0, False)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
from typing import Any, Dict
4+
5+
from .. import cpp as fstcpp
6+
from ..common import SafeTensorsMetadata
7+
from ..frameworks import FrameworkOpBase, TensorBase
8+
from ..st_types import Device, DeviceType, DType
9+
from .base import CopierInterface, DummyDeviceBuffer
10+
11+
12+
class ExampleCopier(CopierInterface):
    """Minimal skeleton showing how to implement CopierInterface.

    Every method is a no-op: submit_io() hands back a DummyDeviceBuffer
    and wait_io() yields no tensors. Use this class as a template when
    writing a custom copier.
    """

    def __init__(
        self,
        metadata: SafeTensorsMetadata,
        device: Device,
        reader,  # reader object; concrete type is up to the custom copier
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ):
        # Intentionally empty: a real copier would store these arguments
        # for use in submit_io()/wait_io().
        pass

    def submit_io(
        self, use_buf_register: bool, max_copy_block_size: int
    ) -> fstcpp.gds_device_buffer:
        """Pretend to start I/O; returns an empty placeholder buffer."""
        return DummyDeviceBuffer()

    def wait_io(
        self,
        gbuf: fstcpp.gds_device_buffer,
        dtype: DType = DType.AUTO,
        noalign: bool = False,
    ) -> Dict[str, TensorBase]:
        """Pretend to finish I/O; returns no tensors."""
        # get tensor
        res: Dict[str, TensorBase] = {}
        return res
37+
38+
39+
def new_gds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
    nogds: bool = False,
):
    """Example copier factory mirroring the real one in copier/gds.py.

    The parameters are accepted for signature compatibility but unused by
    this example. Returns a per-file constructor closure, as expected by
    BaseSafeTensorsFileLoader.
    """
    # reader = example_reader()
    reader: Any = {}  # stand-in for a real reader object

    def construct_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        # Invoked by the loader once per file's metadata entry.
        return ExampleCopier(metadata, device, reader, framework, debug_log)

    return construct_copier

fastsafetensors/copier/gds.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import warnings
45
from typing import Dict, Optional
56

67
from .. import cpp as fstcpp
78
from ..common import SafeTensorsMetadata
89
from ..frameworks import FrameworkOpBase, TensorBase
910
from ..st_types import Device, DeviceType, DType
11+
from .base import CopierInterface
12+
from .nogds import NoGdsFileCopier
1013

1114

12-
class GdsFileCopier:
15+
class GdsFileCopier(CopierInterface):
1316
def __init__(
1417
self,
1518
metadata: SafeTensorsMetadata,
@@ -139,3 +142,47 @@ def wait_io(
139142
return self.metadata.get_tensors(
140143
gbuf, self.device, self.aligned_offset, dtype=dtype
141144
)
145+
146+
147+
def new_gds_file_copier(
    device: Device,
    bbuf_size_kb: int = 16 * 1024,
    max_threads: int = 16,
    nogds: bool = False,
):
    """Build a per-file copier constructor for the given target device.

    Validates CUDA availability for non-CPU devices, falls back from GDS
    to the pread-based path when libcufile is missing, and returns a
    closure that creates one CopierInterface per file.

    Args:
        device: target device for loaded tensors.
        bbuf_size_kb: bounce-buffer size for the non-GDS path.
        max_threads: maximum number of reader threads.
        nogds: force the pread-based fallback instead of GDS.

    Returns:
        A callable (metadata, device, framework, debug_log) -> CopierInterface.
    """
    on_gpu = device.type != DeviceType.CPU
    if on_gpu and not fstcpp.is_cuda_found():
        raise Exception("[FAIL] libcudart.so does not exist")
    if not nogds and not fstcpp.is_cufile_found():
        # GDS requested but cuFile is unavailable: warn and fall back.
        warnings.warn(
            "libcufile.so does not exist but nogds is False. use nogds=True",
            UserWarning,
        )
        nogds = True

    if not nogds:
        # GDS path: one shared gds_file_reader captured by the closure.
        gds_reader = fstcpp.gds_file_reader(max_threads, on_gpu)

        def make_gds_copier(
            metadata: SafeTensorsMetadata,
            device: Device,
            framework: FrameworkOpBase,
            debug_log: bool = False,
        ) -> CopierInterface:
            return GdsFileCopier(metadata, device, gds_reader, framework, debug_log)

        return make_gds_copier

    # Fallback path: pread into a bounce buffer via nogds_file_reader.
    fallback_reader = fstcpp.nogds_file_reader(
        False, bbuf_size_kb, max_threads, on_gpu
    )

    def make_nogds_copier(
        metadata: SafeTensorsMetadata,
        device: Device,
        framework: FrameworkOpBase,
        debug_log: bool = False,
    ) -> CopierInterface:
        return NoGdsFileCopier(metadata, device, fallback_reader, framework, debug_log)

    return make_nogds_copier

fastsafetensors/copier/nogds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from .. import cpp as fstcpp
88
from ..common import SafeTensorsMetadata
99
from ..frameworks import FrameworkOpBase, TensorBase
10-
from ..st_types import Device, DType
10+
from ..st_types import Device, DeviceType, DType
11+
from .base import CopierInterface
1112

1213

13-
class NoGdsFileCopier:
14+
class NoGdsFileCopier(CopierInterface):
1415
def __init__(
1516
self,
1617
metadata: SafeTensorsMetadata,

fastsafetensors/loader.py

Lines changed: 79 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,100 +2,70 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import math
5-
import warnings
65
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
76

87
from . import cpp as fstcpp
98
from .common import SafeTensorsMetadata, TensorFrame, get_device_numa_node
9+
from .copier.gds import new_gds_file_copier
1010
from .file_buffer import FilesBufferOnDevice
11-
from .frameworks import TensorBase, get_framework_op
12-
from .st_types import DeviceType, DType
11+
from .frameworks import FrameworkOpBase, TensorBase, get_framework_op
12+
from .st_types import Device, DeviceType, DType
1313
from .tensor_factory import LazyTensorFactory
1414

1515
gl_set_numa = False
1616

1717
loaded_nvidia = False
1818

1919

20-
class SafeTensorsFileLoader:
21-
r"""Load .safetensors files lazily.
20+
class BaseSafeTensorsFileLoader:
21+
r"""Base class for loading .safetensors files lazily.
2222
2323
Args:
24-
devcie (str): target device.
25-
pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
26-
bbuf_size_kb (int): bounce buffer size for file copies.
27-
max_threads (int): maximum number of threads for memory copies.
28-
nogds (bool): if True, trun off GDS and fallback to pread with bounce buffer.
29-
debug_log (bool): enable debug logs.
30-
31-
Examples:
32-
>> from fastsafetensors import SafeTensorsFileLoader
33-
>> src_files = download(target_dir, "gpt2")
34-
>> loader = SafeTensorsFileLoader(Device("cpu"), nogds=True, debug_log=True)
35-
>> loader.add_filenames({0: src_files})
36-
>> bufs = loader.copy_files_to_device()
37-
>> print(bufs.get_tensor(loader.get_keys()[0]))
38-
>> loader.close()
24+
pg (Optional[Any]): Process group-like objects for distributed loading.
25+
Use None for single device use-cases.
26+
device (Device): Target device where tensors will be loaded (CPU, CUDA, etc.).
27+
copier_constructor: Constructor function for creating file copier objects.
28+
set_numa (bool): Whether to set NUMA node affinity for optimized memory access.
29+
disable_cache (bool): Whether to disable caching of loaded tensors.
30+
debug_log (bool): Enable detailed debug logging.
31+
framework (str): Deep learning framework to use ("pytorch" or "paddle").
3932
"""
4033

4134
def __init__(
    self,
    pg: Optional[Any],
    device: Device,
    copier_constructor,  # callable(metadata, device, framework, debug_log) -> CopierInterface
    set_numa: bool = True,
    disable_cache: bool = True,
    debug_log: bool = False,
    framework="pytorch",
):
    # Resolve the framework-specific operation table ("pytorch"/"paddle").
    self.framework = get_framework_op(framework)
    # Wrap the (possibly None) process-group-like object.
    self.pg = self.framework.get_process_group(pg)
    # Unlike the SafeTensorsFileLoader subclass, the base class takes an
    # already-resolved Device object, not a device string.
    self.device = device
    self.debug_log = debug_log
    # Per-file metadata keyed by filename: (metadata, owning rank).
    self.meta: Dict[str, Tuple[SafeTensorsMetadata, int]] = {}
    # Tensor name -> frame, insertion-ordered across registered files.
    self.frames = OrderedDict[str, TensorFrame]()
    self.disable_cache = disable_cache
    # Optionally pin allocations to the device's NUMA node (once per process).
    self.init_numa(set_numa)
    # Factory used later to build one copier per file.
    self.copier_constructor = copier_constructor
53+
54+
def init_numa(self, set_numa: bool = True):
    """Bind memory allocation to the device's NUMA node, at most once per process.

    The module-level gl_set_numa flag ensures only the first loader
    instance performs the binding.
    """
    global gl_set_numa
    if not gl_set_numa and set_numa:
        # get_device_numa_node may return None (e.g. no NUMA affinity for
        # this device index), in which case no binding is done.
        node = get_device_numa_node(self.device.index)
        if node is not None:
            fstcpp.set_numa_node(node)
        gl_set_numa = True
74-
fstcpp.set_debug_log(debug_log)
75-
device_is_not_cpu = self.device.type != DeviceType.CPU
76-
if device_is_not_cpu and not fstcpp.is_cuda_found():
77-
raise Exception("[FAIL] libcudart.so does not exist")
78-
if not fstcpp.is_cufile_found() and not nogds:
79-
warnings.warn(
80-
"libcufile.so does not exist but nogds is False. use nogds=True",
81-
UserWarning,
82-
)
83-
nogds = True
84-
self.reader: Union[fstcpp.nogds_file_reader, fstcpp.gds_file_reader]
85-
if nogds:
86-
self.reader = fstcpp.nogds_file_reader(
87-
False, bbuf_size_kb, max_threads, device_is_not_cpu
88-
)
89-
else:
90-
self.reader = fstcpp.gds_file_reader(max_threads, device_is_not_cpu)
9161

9262
def reset(self):
    """Drop all registered file metadata and tensor frames.

    NOTE(review): frames is re-created as a plain dict here while
    __init__ uses OrderedDict; equivalent ordering-wise on Python 3.7+.
    """
    self.frames = {}
    self.meta = {}
9565

9666
def close(self):
    """Release loader state.

    Clears metadata/frames and drops the copier constructor so any
    reader resources it captured can be garbage-collected.
    """
    self.reset()
    del self.copier_constructor
9969

10070
def get_keys(self) -> List[str]:
    """Return the names of all currently registered tensors."""
    return [name for name in self.frames]
@@ -145,8 +115,10 @@ def copy_files_to_device(
145115

146116
factory_idx_bits = math.ceil(math.log2(len(self.meta) + 1))
147117
lidx = 1
148-
149118
for _, (meta, rank) in sorted(self.meta.items(), key=lambda x: x[0]):
119+
copier = self.copier_constructor(
120+
meta, self.device, self.framework, self.debug_log
121+
)
150122
self_rank = self.pg.rank() == rank
151123
factory = LazyTensorFactory(
152124
meta,
@@ -155,7 +127,7 @@ def copy_files_to_device(
155127
self_rank,
156128
factory_idx_bits,
157129
lidx,
158-
self.reader,
130+
copier,
159131
self.framework,
160132
self.debug_log,
161133
disable_cache=self.disable_cache,
@@ -166,12 +138,63 @@ def copy_files_to_device(
166138
need_wait.append(factory)
167139
lidx += 1
168140
for factory in need_wait:
169-
factory.wait_io(
170-
dtype=dtype, noalign=isinstance(self.reader, fstcpp.nogds_file_reader)
171-
)
141+
factory.wait_io(dtype=dtype, noalign=False)
172142
return FilesBufferOnDevice(factories, pg=self.pg, framework=self.framework)
173143

174144

145+
class SafeTensorsFileLoader(BaseSafeTensorsFileLoader):
    r"""Load .safetensors files lazily.

    Args:
        device (str): target device.
        pg (Optional[Any]): process group-like objects for distributed. None for single GPU use-cases.
        bbuf_size_kb (int): bounce buffer size for file copies.
        max_threads (int): maximum number of threads for memory copies.
        nogds (bool): if True, turn off GDS and fall back to pread with a bounce buffer.
        debug_log (bool): enable debug logs.

    Examples:
        >> from fastsafetensors import SafeTensorsFileLoader
        >> src_files = download(target_dir, "gpt2")
        >> loader = SafeTensorsFileLoader(Device("cpu"), nogds=True, debug_log=True)
        >> loader.add_filenames({0: src_files})
        >> bufs = loader.copy_files_to_device()
        >> print(bufs.get_tensor(loader.get_keys()[0]))
        >> loader.close()
    """

    def __init__(
        self,
        pg: Optional[Any],
        device: str = "cpu",
        bbuf_size_kb: int = 16 * 1024,
        max_threads: int = 16,
        nogds: bool = False,
        set_numa: bool = True,
        disable_cache: bool = True,
        debug_log: bool = False,
        framework="pytorch",
    ):
        self.framework = get_framework_op(framework)
        self.pg = self.framework.get_process_group(pg)
        # Resolve the device string into a framework Device object.
        self.device = self.framework.get_device(device, self.pg)

        fstcpp.set_debug_log(debug_log)
        # Load CUDA/cuFile symbols once per process; GDS initialization is
        # skipped in the nogds case (see comment below).
        global loaded_nvidia
        if not loaded_nvidia:
            fstcpp.load_nvidia_functions()
            if not nogds:
                # no need to init gds and consume 10s+ in none-gds case
                if fstcpp.init_gds() != 0:
                    raise Exception(f"[FAIL] init_gds()")
            loaded_nvidia = True

        # Default GDS-backed copier factory (falls back to pread when
        # nogds=True or libcufile is unavailable).
        copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
        super().__init__(
            pg, self.device, copier, set_numa, disable_cache, debug_log, framework
        )
196+
197+
175198
class fastsafe_open:
176199
"""
177200
Opens a safetensors lazily and returns tensors as asked

0 commit comments

Comments
 (0)