Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test-paddle.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ jobs:
fi
pip install torch==${torch_ver} --index-url https://download.pytorch.org/whl/cpu # transformers requires torch
pip install paddlepaddle==3.0.0
pip install pytest pytest-cov setuptools_scm safetensors transformers==${tf_ver} numpy==${npy_ver}
# TODO: unpin safetensors once the fp8 regression in 0.7.0 (reported 2025-12-05) is fixed
pip install pytest pytest-cov setuptools_scm safetensors==0.6.2 transformers==${tf_ver} numpy==${npy_ver}
- name: Build Package
run: |
pip install .
Expand Down
22 changes: 17 additions & 5 deletions fastsafetensors/copier/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,24 @@ def new_gds_file_copier(
raise Exception(
"[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist"
)
if not fstcpp.is_cufile_found() and not nogds:
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
UserWarning,
if device_is_not_cpu and not nogds:
gds_supported = fstcpp.is_gds_supported(
device.index if device.index is not None else 0
)
nogds = True
if gds_supported < 0:
raise Exception(f"is_gds_supported({device.index}) failed")
if not fstcpp.is_cufile_found():
warnings.warn(
"libcufile.so does not exist but nogds is False. use nogds=True",
UserWarning,
)
nogds = True
elif gds_supported == 0:
warnings.warn(
"GDS is not supported in this platform but nogds is False. use nogds=True",
UserWarning,
)
nogds = True

if nogds:
nogds_reader = fstcpp.nogds_file_reader(
Expand Down
1 change: 1 addition & 0 deletions fastsafetensors/cpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def is_cufile_found() -> bool: ...
def cufile_version() -> int: ...
def get_alignment_size() -> int: ...
def set_debug_log(debug_log: bool) -> None: ...
def is_gds_supported(deviceId: int) -> int: ...
def init_gds() -> int: ...
def close_gds() -> int: ...
def get_device_pci_bus(deviceId: int) -> str: ...
Expand Down
37 changes: 36 additions & 1 deletion fastsafetensors/cpp/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,13 @@ static void load_nvidia_functions() {
mydlsym(&cuda_fns.cudaDeviceGetPCIBusId, handle_cudart, "cudaDeviceGetPCIBusId");
mydlsym(&cuda_fns.cudaDeviceMalloc, handle_cudart, "cudaMalloc");
mydlsym(&cuda_fns.cudaDeviceFree, handle_cudart, "cudaFree");
bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc && cuda_fns.cudaDeviceFree;
mydlsym(&cuda_fns.cudaDriverGetVersion, handle_cudart, "cudaDriverGetVersion");
mydlsym(&cuda_fns.cudaDeviceGetAttribute, handle_cudart, "cudaDeviceGetAttribute");
bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize;
success = success && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost;
success = success && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc;
success = success && cuda_fns.cudaDeviceFree && cuda_fns.cudaDriverGetVersion;
success = success && cuda_fns.cudaDeviceGetAttribute;
if (!success) {
cuda_found = false;
if (init_log) {
Expand All @@ -159,6 +165,8 @@ static void load_nvidia_functions() {
}
}
dlclose(handle_cudart);
} else if (init_log) {
fprintf(stderr, "[DEBUG] %s is not installed. fallback\n", cudartLib);
}
if (!cuda_found) {
cuda_fns.cudaMemcpy = cpu_cudaMemcpy;
Expand Down Expand Up @@ -291,6 +299,32 @@ void init_gil_release_from_env() {
}
}

// Probe whether GPUDirect Storage (GDS) can be used with the given device.
//
// Returns:
//   1  - GDS appears supported (or cannot be ruled out on old drivers)
//   0  - not supported (always the case on ROCm builds)
//  -1  - a CUDA runtime query failed (details printed to stderr)
int is_gds_supported(int deviceId)
{
#ifdef USE_ROCM
    // ROCm does not have GDS; the CUDA-only probe below is compiled out.
    (void)deviceId; // silence -Wunused-parameter on ROCm builds
    return 0;
#else
    int gdr_support = 1;
    int driverVersion = 0;
    cudaError_t err;

    err = cuda_fns.cudaDriverGetVersion(&driverVersion);
    if (err != cudaSuccess) {
        std::fprintf(stderr, "is_gds_supported: cudaDriverGetVersion failed, deviceId=%d, err=%d\n", deviceId, err);
        return -1;
    }

    // cudaDevAttrGPUDirectRDMASupported can only be queried on drivers newer
    // than 11030 (presumably CUDA 11.3 -- TODO confirm the exact cutoff);
    // on older drivers we keep the optimistic default of 1, as before.
    if (driverVersion > 11030) {
        err = cuda_fns.cudaDeviceGetAttribute(&gdr_support, cudaDevAttrGPUDirectRDMASupported, deviceId);
        if (err != cudaSuccess) {
            std::fprintf(stderr, "is_gds_supported: cudaDeviceGetAttribute failed, deviceId=%d, err=%d\n", deviceId, err);
            return -1;
        }
    }
    return gdr_support;
#endif
}

int init_gds()
{
CUfileError_t err;
Expand Down Expand Up @@ -787,6 +821,7 @@ PYBIND11_MODULE(__MOD_NAME__, m)
m.def("cufile_version", &cufile_version);
m.def("set_debug_log", &set_debug_log);
m.def("get_alignment_size", &get_alignment_size);
m.def("is_gds_supported", &is_gds_supported);
m.def("init_gds", &init_gds);
m.def("close_gds", &close_gds);
m.def("get_device_pci_bus", &get_device_pci_bus);
Expand Down
3 changes: 3 additions & 0 deletions fastsafetensors/cpp/ext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ typedef struct CUfileError { CUfileOpError err; } CUfileError_t;
// Define minimal CUDA/HIP types for both platforms to avoid compile-time dependencies
// We load all GPU functions dynamically at runtime via dlopen()
typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t;
enum cudaDeviceAttr {cudaDevAttrGPUDirectRDMASupported = 116};
// Platform-specific enum values - CUDA and HIP have different values for HostToDevice
#ifdef USE_ROCM
enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4 };
Expand Down Expand Up @@ -212,6 +213,8 @@ typedef struct ext_funcs {
cudaError_t (*cudaDeviceMalloc)(void **, size_t);
cudaError_t (*cudaDeviceFree)(void *);
int (*numa_run_on_node)(int);
cudaError_t (*cudaDriverGetVersion)(int *);
cudaError_t (*cudaDeviceGetAttribute)(int *, enum cudaDeviceAttr, int);
} ext_funcs_t;

typedef struct cpp_metrics {
Expand Down