Skip to content

Commit d9d23b6

Browse files
authored
add parallel loader (#33)
Signed-off-by: yuanyuxing.yyx <[email protected]>
1 parent d6f998a commit d9d23b6

File tree

6 files changed

+567
-7
lines changed

6 files changed

+567
-7
lines changed

fastsafetensors/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55

66
__version__ = version(__name__)
77

8-
from .common import SafeTensorsMetadata, SingleGroup, TensorFrame, get_device_numa_node
8+
from .common import (
9+
SafeTensorsMetadata,
10+
SingleGroup,
11+
TensorFrame,
12+
get_device_numa_node,
13+
)
914
from .file_buffer import FilesBufferOnDevice
10-
from .loader import SafeTensorsFileLoader, fastsafe_open
15+
from .loader import BaseSafeTensorsFileLoader, SafeTensorsFileLoader, fastsafe_open
16+
from .parallel_loader import ParallelLoader

fastsafetensors/cpp.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,5 @@ def gpu_malloc(length: int) -> int: ...
5757
def gpu_free(addr: int) -> None: ...
5858
def load_nvidia_functions() -> None: ...
5959
def get_cpp_metrics() -> cpp_metrics: ...
60+
def set_gil_release(gil_release: bool) -> None: ...
61+
def get_gil_release() -> bool: ...

fastsafetensors/cpp/ext.cpp

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99
#include <sys/mman.h>
1010
#include <chrono>
1111
#include <dlfcn.h>
12+
#include <cstdlib>
13+
#include <algorithm>
1214

1315
#include "cuda_compat.h"
1416
#include "ext.hpp"
1517

1618
#define ALIGN 4096
1719

1820
static bool debug_log = false;
21+
static bool enable_gil_release = false;
1922

2023
static cpp_metrics_t mc = {.bounce_buffer_bytes = 0};
2124

@@ -266,6 +269,28 @@ void set_debug_log(bool _debug_log)
266269
debug_log = _debug_log;
267270
}
268271

272+
void set_gil_release(bool enable) {
273+
enable_gil_release = enable;
274+
}
275+
276+
bool get_gil_release() {
277+
return enable_gil_release;
278+
}
279+
280+
void init_gil_release_from_env() {
281+
const char* env_val = std::getenv("FASTSAFETENSORS_ENABLE_GIL_RELEASE");
282+
if (env_val != nullptr) {
283+
std::string env_str(env_val);
284+
// Convert to lowercase for case-insensitive comparison
285+
std::transform(env_str.begin(), env_str.end(), env_str.begin(), ::tolower);
286+
enable_gil_release = (env_str == "1" || env_str == "true" || env_str == "yes" || env_str == "on");
287+
if (debug_log) {
288+
std::printf("[DEBUG] GIL release %s via environment variable FASTSAFETENSORS_ENABLE_GIL_RELEASE=%s\n",
289+
enable_gil_release ? "enabled" : "disabled", env_val);
290+
}
291+
}
292+
}
293+
269294
int init_gds()
270295
{
271296
CUfileError_t err;
@@ -741,6 +766,8 @@ cpp_metrics_t get_cpp_metrics() {
741766

742767
PYBIND11_MODULE(__MOD_NAME__, m)
743768
{
769+
// Initialize GIL release setting from environment variable on module load
770+
init_gil_release_from_env();
744771
// Export both is_cuda_found and is_hip_found on all platforms
745772
// Use string concatenation to prevent hipify from converting the export names
746773
#ifdef USE_ROCM
@@ -771,6 +798,8 @@ PYBIND11_MODULE(__MOD_NAME__, m)
771798
m.def("gpu_free", &gpu_free);
772799
m.def("load_nvidia_functions", &load_nvidia_functions);
773800
m.def("get_cpp_metrics", &get_cpp_metrics);
801+
m.def("set_gil_release", &set_gil_release);
802+
m.def("get_gil_release", &get_gil_release);
774803

775804
pybind11::class_<gds_device_buffer>(m, "gds_device_buffer")
776805
.def(pybind11::init<const uintptr_t, const uint64_t, bool>())
@@ -780,18 +809,56 @@ PYBIND11_MODULE(__MOD_NAME__, m)
780809
.def("get_base_address", &gds_device_buffer::get_base_address)
781810
.def("get_length", &gds_device_buffer::get_length);
782811

812+
// Helper lambdas to conditionally apply GIL release
813+
auto nogds_submit_read = [](nogds_file_reader& self, const int fd, const gds_device_buffer& dst, const int64_t offset, const int64_t length, const uint64_t ptr_off) {
814+
if (enable_gil_release) {
815+
pybind11::gil_scoped_release release;
816+
return self.submit_read(fd, dst, offset, length, ptr_off);
817+
} else {
818+
return self.submit_read(fd, dst, offset, length, ptr_off);
819+
}
820+
};
821+
822+
auto nogds_wait_read = [](nogds_file_reader& self, const int thread_id) {
823+
if (enable_gil_release) {
824+
pybind11::gil_scoped_release release;
825+
return self.wait_read(thread_id);
826+
} else {
827+
return self.wait_read(thread_id);
828+
}
829+
};
830+
783831
pybind11::class_<nogds_file_reader>(m, "nogds_file_reader")
784832
.def(pybind11::init<const bool, const uint64_t, const int, bool>())
785-
.def("submit_read", &nogds_file_reader::submit_read)
786-
.def("wait_read", &nogds_file_reader::wait_read);
833+
.def("submit_read", nogds_submit_read)
834+
.def("wait_read", nogds_wait_read);
787835

788836
pybind11::class_<gds_file_handle>(m, "gds_file_handle")
789837
.def(pybind11::init<std::string, bool, bool>());
790838

839+
// Helper lambdas for gds_file_reader to conditionally apply GIL release
840+
auto gds_submit_read = [](gds_file_reader& self, const gds_file_handle &fh, const gds_device_buffer &dst, const uint64_t offset, const uint64_t length, const uint64_t ptr_off, const uint64_t file_length) {
841+
if (enable_gil_release) {
842+
pybind11::gil_scoped_release release;
843+
return self.submit_read(fh, dst, offset, length, ptr_off, file_length);
844+
} else {
845+
return self.submit_read(fh, dst, offset, length, ptr_off, file_length);
846+
}
847+
};
848+
849+
auto gds_wait_read = [](gds_file_reader& self, const int id) {
850+
if (enable_gil_release) {
851+
pybind11::gil_scoped_release release;
852+
return self.wait_read(id);
853+
} else {
854+
return self.wait_read(id);
855+
}
856+
};
857+
791858
pybind11::class_<gds_file_reader>(m, "gds_file_reader")
792859
.def(pybind11::init<const int, bool>())
793-
.def("submit_read", &gds_file_reader::submit_read)
794-
.def("wait_read", &gds_file_reader::wait_read);
860+
.def("submit_read", gds_submit_read)
861+
.def("wait_read", gds_wait_read);
795862

796863
pybind11::class_<cpp_metrics_t>(m, "cpp_metrics")
797864
.def(pybind11::init<>())

fastsafetensors/cpp/ext.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ typedef struct CUfileDrvProps {
7171

7272
int get_alignment_size();
7373
void set_debug_log(bool _debug_log);
74+
void set_gil_release(bool enable);
75+
bool get_gil_release();
76+
void init_gil_release_from_env();
7477
int init_gds();
7578
int close_gds();
7679
std::string get_device_pci_bus(int deviceId);

fastsafetensors/loader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
disable_cache: bool = True,
4141
debug_log: bool = False,
4242
framework="pytorch",
43+
**kwargs,
4344
):
4445
self.framework = get_framework_op(framework)
4546
self.pg = self.framework.get_process_group(pg)
@@ -174,6 +175,7 @@ def __init__(
174175
disable_cache: bool = True,
175176
debug_log: bool = False,
176177
framework="pytorch",
178+
**kwargs,
177179
):
178180
self.framework = get_framework_op(framework)
179181
self.pg = self.framework.get_process_group(pg)
@@ -191,7 +193,14 @@ def __init__(
191193

192194
copier = new_gds_file_copier(self.device, bbuf_size_kb, max_threads, nogds)
193195
super().__init__(
194-
pg, self.device, copier, set_numa, disable_cache, debug_log, framework
196+
pg,
197+
self.device,
198+
copier,
199+
set_numa,
200+
disable_cache,
201+
debug_log,
202+
framework,
203+
**kwargs,
195204
)
196205

197206

0 commit comments

Comments
 (0)