Skip to content

Commit 09c69de

Browse files
Check CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED (#39)
Signed-off-by: Takeshi Yoshimura <[email protected]>
1 parent d9d23b6 commit 09c69de

File tree

5 files changed

+59
-7
lines changed

5 files changed

+59
-7
lines changed

.github/workflows/test-paddle.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ jobs:
4848
fi
4949
pip install torch==${torch_ver} --index-url https://download.pytorch.org/whl/cpu # transformers requires torch
5050
pip install paddlepaddle==3.0.0
51-
pip install pytest pytest-cov setuptools_scm safetensors transformers==${tf_ver} numpy==${npy_ver}
51+
# TODO: unpin the safetensors version (0.7.0 had an fp8-related bug as of Dec 5, 2025)
52+
pip install pytest pytest-cov setuptools_scm safetensors==0.6.2 transformers==${tf_ver} numpy==${npy_ver}
5253
- name: Build Package
5354
run: |
5455
pip install .

fastsafetensors/copier/gds.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,24 @@ def new_gds_file_copier(
172172
raise Exception(
173173
"[FAIL] GPU runtime library (libcudart.so or libamdhip64.so) does not exist"
174174
)
175-
if not fstcpp.is_cufile_found() and not nogds:
176-
warnings.warn(
177-
"libcufile.so does not exist but nogds is False. use nogds=True",
178-
UserWarning,
175+
if device_is_not_cpu and not nogds:
176+
gds_supported = fstcpp.is_gds_supported(
177+
device.index if device.index is not None else 0
179178
)
180-
nogds = True
179+
if gds_supported < 0:
180+
raise Exception(f"is_gds_supported({device.index}) failed")
181+
if not fstcpp.is_cufile_found():
182+
warnings.warn(
183+
"libcufile.so does not exist but nogds is False. use nogds=True",
184+
UserWarning,
185+
)
186+
nogds = True
187+
elif gds_supported == 0:
188+
warnings.warn(
189+
"GDS is not supported in this platform but nogds is False. use nogds=True",
190+
UserWarning,
191+
)
192+
nogds = True
181193

182194
if nogds:
183195
nogds_reader = fstcpp.nogds_file_reader(

fastsafetensors/cpp.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def is_cufile_found() -> bool: ...
4646
def cufile_version() -> int: ...
4747
def get_alignment_size() -> int: ...
4848
def set_debug_log(debug_log: bool) -> None: ...
49+
def is_gds_supported(deviceId: int) -> int: ...
4950
def init_gds() -> int: ...
5051
def close_gds() -> int: ...
5152
def get_device_pci_bus(deviceId: int) -> str: ...

fastsafetensors/cpp/ext.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,13 @@ static void load_nvidia_functions() {
148148
mydlsym(&cuda_fns.cudaDeviceGetPCIBusId, handle_cudart, "cudaDeviceGetPCIBusId");
149149
mydlsym(&cuda_fns.cudaDeviceMalloc, handle_cudart, "cudaMalloc");
150150
mydlsym(&cuda_fns.cudaDeviceFree, handle_cudart, "cudaFree");
151-
bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc && cuda_fns.cudaDeviceFree;
151+
mydlsym(&cuda_fns.cudaDriverGetVersion, handle_cudart, "cudaDriverGetVersion");
152+
mydlsym(&cuda_fns.cudaDeviceGetAttribute, handle_cudart, "cudaDeviceGetAttribute");
153+
bool success = cuda_fns.cudaMemcpy && cuda_fns.cudaDeviceSynchronize;
154+
success = success && cuda_fns.cudaHostAlloc && cuda_fns.cudaFreeHost;
155+
success = success && cuda_fns.cudaDeviceGetPCIBusId && cuda_fns.cudaDeviceMalloc;
156+
success = success && cuda_fns.cudaDeviceFree && cuda_fns.cudaDriverGetVersion;
157+
success = success && cuda_fns.cudaDeviceGetAttribute;
152158
if (!success) {
153159
cuda_found = false;
154160
if (init_log) {
@@ -159,6 +165,8 @@ static void load_nvidia_functions() {
159165
}
160166
}
161167
dlclose(handle_cudart);
168+
} else if (init_log) {
169+
fprintf(stderr, "[DEBUG] %s is not installed. fallback\n", cudartLib);
162170
}
163171
if (!cuda_found) {
164172
cuda_fns.cudaMemcpy = cpu_cudaMemcpy;
@@ -291,6 +299,32 @@ void init_gil_release_from_env() {
291299
}
292300
}
293301

302+
int is_gds_supported(int deviceId)
303+
{
304+
#ifndef USE_ROCM
305+
int gdr_support = 1;
306+
int driverVersion = 0;
307+
cudaError_t err;
308+
309+
err = cuda_fns.cudaDriverGetVersion(&driverVersion);
310+
if (err != cudaSuccess) {
311+
std::fprintf(stderr, "is_gds_supported: cudaDriverGetVersion failed, deviceId=%d, err=%d\n", deviceId, err);
312+
return -1;
313+
}
314+
315+
if (driverVersion > 11030) {
316+
err = cuda_fns.cudaDeviceGetAttribute(&gdr_support, cudaDevAttrGPUDirectRDMASupported, deviceId);
317+
if (err != cudaSuccess) {
318+
std::fprintf(stderr, "is_gds_supported: cudaDeviceGetAttribute failed, deviceId=%d, err=%d\n", deviceId, err);
319+
return -1;
320+
}
321+
}
322+
return gdr_support;
323+
#endif
324+
// ROCm does not have GDS
325+
return 0;
326+
}
327+
294328
int init_gds()
295329
{
296330
CUfileError_t err;
@@ -787,6 +821,7 @@ PYBIND11_MODULE(__MOD_NAME__, m)
787821
m.def("cufile_version", &cufile_version);
788822
m.def("set_debug_log", &set_debug_log);
789823
m.def("get_alignment_size", &get_alignment_size);
824+
m.def("is_gds_supported", &is_gds_supported);
790825
m.def("init_gds", &init_gds);
791826
m.def("close_gds", &close_gds);
792827
m.def("get_device_pci_bus", &get_device_pci_bus);

fastsafetensors/cpp/ext.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ typedef struct CUfileError { CUfileOpError err; } CUfileError_t;
3939
// Define minimal CUDA/HIP types for both platforms to avoid compile-time dependencies
4040
// We load all GPU functions dynamically at runtime via dlopen()
4141
typedef enum cudaError { cudaSuccess = 0, cudaErrorMemoryAllocation = 2 } cudaError_t;
42+
// Minimal stand-in for the CUDA runtime's cudaDeviceAttr enum: it declares only
// cudaDevAttrGPUDirectRDMASupported, the attribute is_gds_supported() passes to
// cudaDeviceGetAttribute.
// NOTE(review): 116 is expected to match the value in the CUDA runtime headers,
// since no CUDA header is included at compile time — confirm against driver_types.h.
enum cudaDeviceAttr {cudaDevAttrGPUDirectRDMASupported = 116};
4243
// Platform-specific enum values - CUDA and HIP have different values for HostToDevice
4344
#ifdef USE_ROCM
4445
enum cudaMemcpyKind { cudaMemcpyHostToDevice=1, cudaMemcpyDefault = 4 };
@@ -212,6 +213,8 @@ typedef struct ext_funcs {
212213
cudaError_t (*cudaDeviceMalloc)(void **, size_t);
213214
cudaError_t (*cudaDeviceFree)(void *);
214215
int (*numa_run_on_node)(int);
216+
cudaError_t (*cudaDriverGetVersion)(int *);
217+
cudaError_t (*cudaDeviceGetAttribute)(int *, enum cudaDeviceAttr, int);
215218
} ext_funcs_t;
216219

217220
typedef struct cpp_metrics {

0 commit comments

Comments
 (0)