Skip to content

Commit e4e085e

Browse files
refactor for lint
TODO: STProcessGroupWrapper Signed-off-by: Takeshi Yoshimura <[email protected]>
1 parent 041b899 commit e4e085e

File tree

12 files changed

+957
-438
lines changed

12 files changed

+957
-438
lines changed

fastsafetensors/__init__.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from .common import (SafeTensorsMetadata, SingleGroup, TensorFrame,
5-
alloc_tensor_memory, free_tensor_memory,
6-
get_device_numa_node, str_to_dtype)
4+
from .common import (
5+
SafeTensorsMetadata,
6+
TensorFrame,
7+
alloc_tensor_memory,
8+
free_tensor_memory,
9+
get_device_numa_node,
10+
)
711
from .file_buffer import FilesBufferOnDevice
812
from .loader import SafeTensorsFileLoader, fastsafe_open
13+
from .st_types import SingleGroup, STDevice, STDeviceType, STEnv

fastsafetensors/common.py

Lines changed: 247 additions & 137 deletions
Large diffs are not rendered by default.

fastsafetensors/copier/gds.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,55 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from typing import Dict, Optional
5+
46
import torch
7+
58
from .. import cpp as fstcpp
6-
from typing import Dict
7-
from ..common import alloc_tensor_memory, free_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN, paddle_loaded
8-
if paddle_loaded:
9-
import paddle
9+
from ..common import (
10+
ALIGN,
11+
CUDA_PTR_ALIGN,
12+
CUDA_VER,
13+
SafeTensorsMetadata,
14+
alloc_tensor_memory,
15+
free_tensor_memory,
16+
)
17+
from ..st_types import STDevice, STDeviceType, STDType
18+
1019

1120
class GdsFileCopier:
12-
def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.gds_file_reader, debug_log: bool=False):
21+
def __init__(
22+
self,
23+
metadata: SafeTensorsMetadata,
24+
device: STDevice,
25+
reader: fstcpp.gds_file_reader,
26+
debug_log: bool = False,
27+
):
1328
self.metadata = metadata
1429
self.device = device
1530
self.reader = reader
1631
self.debug_log = debug_log
1732
self.gbuf = None
18-
self.fh = 0
33+
self.fh: Optional[fstcpp.gds_file_handle] = None
1934
self.copy_reqs: Dict[int, int] = {}
2035
self.aligned_length = 0
21-
try:
22-
if self.metadata.framework == "pytorch":
23-
cuda_vers_list = torch.version.cuda.split('.')
24-
elif paddle_loaded and self.metadata.framework == "paddle":
25-
cuda_vers_list = paddle.version.cuda().split('.')
26-
cudavers = list(map(int, cuda_vers_list))
27-
# CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
28-
# Compatible with CUDA 11.x
29-
self.o_direct = not (cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2))
30-
except:
31-
self.o_direct = True
36+
cudavers = list(map(int, CUDA_VER.split(".")))
37+
# CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
38+
# Compatible with CUDA 11.x
39+
self.o_direct = not (
40+
cudavers[0] > 12 or (cudavers[0] == 12 and cudavers[1] >= 2)
41+
)
3242

3343
def set_o_direct(self, enable: bool):
3444
self.o_direct = enable
3545

36-
def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer:
37-
dev_is_cuda = (self.metadata.framework == "pytorch" and self.device.type == 'cuda') or (paddle_loaded and self.metadata.framework == "paddle" and "gpu" in self.device)
46+
def submit_io(
47+
self, use_buf_register: bool, max_copy_block_size: int
48+
) -> fstcpp.gds_device_buffer:
49+
dev_is_cuda = (
50+
self.device.type == STDeviceType.CUDA
51+
or self.device.type == STDeviceType.GPU
52+
)
3853
self.fh = fstcpp.gds_file_handle(self.metadata.src, self.o_direct, dev_is_cuda)
3954
offset = self.metadata.header_length
4055
length = self.metadata.size_bytes - self.metadata.header_length
@@ -55,7 +70,11 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd
5570
if req_len > max_copy_block_size:
5671
req_len = max_copy_block_size
5772
if gbuf.cufile_register(count, req_len) < 0:
58-
raise Exception("submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}".format(gbuf.get_base_address(), count, req_len))
73+
raise Exception(
74+
"submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}".format(
75+
gbuf.get_base_address(), count, req_len
76+
)
77+
)
5978
count += req_len
6079

6180
count = 0
@@ -64,40 +83,63 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd
6483
if req_len > max_copy_block_size:
6584
req_len = max_copy_block_size
6685
# TODO: pass timeout so that wait_copy_tensors can recognize too slow pread()
67-
req = self.reader.submit_read(self.fh, gbuf, aligned_offset + count, req_len, count, self.metadata.size_bytes)
86+
req = self.reader.submit_read(
87+
self.fh,
88+
gbuf,
89+
aligned_offset + count,
90+
req_len,
91+
count,
92+
self.metadata.size_bytes,
93+
)
6894
self.copy_reqs[req] = -1 if not use_buf_register else count
6995
count += req_len
7096
self.aligned_offset = aligned_offset
7197
self.aligned_length = aligned_length
7298
return gbuf
7399

74-
def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noalign: bool=False)->Dict[str, torch.Tensor]:
100+
def wait_io(
101+
self,
102+
gbuf: fstcpp.gds_device_buffer,
103+
dtype: STDType = STDType.AUTO,
104+
noalign: bool = False,
105+
) -> Dict[str, torch.Tensor]:
75106
failed = []
76-
for req, c in sorted(self.copy_reqs.items(), key=lambda x:x[0]):
107+
for req, c in sorted(self.copy_reqs.items(), key=lambda x: x[0]):
77108
count = self.reader.wait_read(req)
78109
if count < 0:
79110
failed.append(req)
80111
if c != -1:
81112
gbuf.cufile_deregister(c)
82-
if self.fh != 0:
113+
if self.fh is not None:
83114
del self.fh
84-
self.fh = 0
115+
self.fh = None
85116
if len(failed) > 0:
86-
raise Exception(f"wait_io: wait_gds_read failed, failed={failed}, reqs={self.copy_reqs}")
117+
raise Exception(
118+
f"wait_io: wait_gds_read failed, failed={failed}, reqs={self.copy_reqs}"
119+
)
87120
self.copy_reqs = {}
88121
if not noalign and not self.metadata.aligned and self.aligned_length > 0:
89122
misaligned_bytes = self.metadata.header_length % CUDA_PTR_ALIGN
90-
length = 1024*1024*1024
123+
length = 1024 * 1024 * 1024
91124
tmp_gbuf = alloc_tensor_memory(length, self.device, self.metadata.framework)
92125
count = 0
93126
while count + misaligned_bytes < self.aligned_length:
94127
l = self.aligned_length - misaligned_bytes - count
95128
if l > length:
96129
l = length
97130
if self.debug_log:
98-
print("wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format(gbuf.get_base_address(), misaligned_bytes, count, tmp_gbuf.get_base_address()))
131+
print(
132+
"wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}".format(
133+
gbuf.get_base_address(),
134+
misaligned_bytes,
135+
count,
136+
tmp_gbuf.get_base_address(),
137+
)
138+
)
99139
gbuf.memmove(count, misaligned_bytes + count, tmp_gbuf, l)
100140
count += l
101141
free_tensor_memory(tmp_gbuf, self.device, self.metadata.framework)
102142
self.aligned_offset += misaligned_bytes
103-
return self.metadata.get_tensors(gbuf, self.device, self.aligned_offset, dtype=dtype)
143+
return self.metadata.get_tensors(
144+
gbuf, self.device, self.aligned_offset, dtype=dtype
145+
)

fastsafetensors/copier/nogds.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,67 @@
11
# Copyright 2024 IBM Inc. All rights reserved
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import torch
54
import os
5+
from typing import Dict, List
6+
7+
import torch
8+
69
from .. import cpp as fstcpp
7-
from typing import Dict
8-
from ..common import alloc_tensor_memory, SafeTensorsMetadata, ALIGN, CUDA_PTR_ALIGN
10+
from ..common import ALIGN, CUDA_PTR_ALIGN, SafeTensorsMetadata, alloc_tensor_memory
11+
from ..st_types import STDevice, STDType
12+
913

1014
class NoGdsFileCopier:
11-
def __init__(self, metadata: SafeTensorsMetadata, device: torch.device, reader: fstcpp.nogds_file_reader, debug_log: bool=False):
15+
def __init__(
16+
self,
17+
metadata: SafeTensorsMetadata,
18+
device: STDevice,
19+
reader: fstcpp.nogds_file_reader,
20+
debug_log: bool = False,
21+
):
1222
self.metadata = metadata
1323
self.reader = reader
1424
self.fd = os.open(metadata.src, os.O_RDONLY, 0o644)
1525
if self.fd < 0:
16-
raise Exception(f"NoGdsFileCopier.__init__: failed to open, file={metadata.src}")
26+
raise Exception(
27+
f"NoGdsFileCopier.__init__: failed to open, file={metadata.src}"
28+
)
1729
self.device = device
1830
self.debug_log = debug_log
19-
self.reqs = []
31+
self.reqs: List[int] = []
2032

21-
def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gds_device_buffer:
33+
def submit_io(
34+
self, use_buf_register: bool, max_copy_block_size: int
35+
) -> fstcpp.gds_device_buffer:
2236
total_length = self.metadata.size_bytes - self.metadata.header_length
2337
gbuf = alloc_tensor_memory(total_length, self.device, self.metadata.framework)
2438
count = 0
2539
while count < total_length:
2640
l = total_length - count
2741
if max_copy_block_size < l:
2842
l = max_copy_block_size
29-
req = self.reader.submit_read(self.fd, gbuf, self.metadata.header_length + count, l, count)
43+
req = self.reader.submit_read(
44+
self.fd, gbuf, self.metadata.header_length + count, l, count
45+
)
3046
if req < 0:
3147
raise Exception(f"submit_io: submit_nogds_read failed, err={req}")
3248
self.reqs.append(req)
3349
count += l
3450
return gbuf
3551

36-
def wait_io(self, gbuf: fstcpp.gds_device_buffer, dtype: torch.dtype=None, noalign: bool=False)->Dict[str, torch.Tensor]:
52+
def wait_io(
53+
self,
54+
gbuf: fstcpp.gds_device_buffer,
55+
dtype: STDType = STDType.AUTO,
56+
noalign: bool = False,
57+
) -> Dict[str, torch.Tensor]:
3758
for req in self.reqs:
3859
count = self.reader.wait_read(req)
3960
if count < 0:
4061
raise Exception(f"wait_io: wait_nogds_read failed, req={req}")
4162
if self.fd > 0:
4263
os.close(self.fd)
4364
self.fd = 0
44-
return self.metadata.get_tensors(gbuf, self.device, self.metadata.header_length, dtype=dtype)
65+
return self.metadata.get_tensors(
66+
gbuf, self.device, self.metadata.header_length, dtype=dtype
67+
)

fastsafetensors/cpp.pyi

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 IBM Inc. All rights reserved
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# fastsafetensors/cpp.pyi
5+
6+
class gds_device_buffer:
7+
def __init__(self, devPtr_base: int, length: int, use_cuda: bool) -> None: ...
8+
def cufile_register(self, offset: int, length: int) -> int: ...
9+
def cufile_deregister(self, offset: int) -> int: ...
10+
def memmove(
11+
self, dst_off: int, src_off: int, tmp: "gds_device_buffer", length: int
12+
) -> int: ...
13+
def get_base_address(self) -> int: ...
14+
15+
class nogds_file_reader:
16+
def __init__(
17+
self, use_mmap: bool, bbuf_size_kb: int, max_threads: int, use_cuda: bool
18+
) -> None: ...
19+
def submit_read(
20+
self, fd: int, dst: gds_device_buffer, offset: int, length: int, ptr_off: int
21+
) -> int: ...
22+
def wait_read(self, thread_id: int) -> int: ...
23+
24+
class gds_file_handle:
25+
def __init__(self, filename: str, o_direct: bool, use_cuda: bool) -> None: ...
26+
27+
class gds_file_reader:
28+
def __init__(self, max_threads: int, use_cuda: bool) -> None: ...
29+
def submit_read(
30+
self,
31+
fh: gds_file_handle,
32+
dst: gds_device_buffer,
33+
offset: int,
34+
length: int,
35+
ptr_off: int,
36+
file_length: int,
37+
) -> int: ...
38+
def wait_read(self, id: int) -> int: ...
39+
40+
def is_cuda_found() -> bool: ...
41+
def is_cufile_found() -> bool: ...
42+
def cufile_version() -> int: ...
43+
def get_alignment_size() -> int: ...
44+
def set_debug_log(debug_log: bool) -> None: ...
45+
def init_gds(
46+
max_direct_io_size_in_kb: int, max_pinned_memory_size_in_kb: int
47+
) -> int: ...
48+
def close_gds() -> int: ...
49+
def get_device_pci_bus(deviceId: int) -> str: ...
50+
def set_numa_node(numa_node: int) -> int: ...
51+
def read_buffer(dst: int, length: int) -> bytes: ...
52+
def cpu_malloc(length: int) -> int: ...
53+
def cpu_free(addr: int) -> None: ...
54+
def gpu_malloc(length: int) -> int: ...
55+
def gpu_free(addr: int) -> None: ...
56+
def load_nvidia_functions() -> None: ...

0 commit comments

Comments (0)