Skip to content

Commit c4994b3

Browse files
authored
ZLUDA v3.8.7 (#66)
* Add dummy cuFFTW library.
* Bump version.
* Implement FFT functions required to run torch fftn, ifftn, and rfftn.
1 parent d60bddb commit c4994b3

File tree

14 files changed

+1049
-193
lines changed

14 files changed

+1049
-193
lines changed

Cargo.lock

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 66 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,66 @@
1-
[workspace]
2-
3-
resolver = "2"
4-
5-
# Remember to also update the project's Cargo.toml
6-
# if it's a top-level project
7-
members = [
8-
"atiadlxx-sys",
9-
"comgr",
10-
"cuda_base",
11-
"cuda_types",
12-
"detours-sys",
13-
"ext/llvm-sys.rs",
14-
"hip_common",
15-
"hip_runtime-sys",
16-
"hipblaslt-sys",
17-
"hipfft-sys",
18-
"hiprt-sys",
19-
"miopen-sys",
20-
"offline_compiler",
21-
"optix_base",
22-
"optix_dump",
23-
"process_address_table",
24-
"ptx",
25-
"rocblas-sys",
26-
"rocm_smi-sys",
27-
"rocsparse-sys",
28-
"xtask",
29-
"zluda",
30-
"zluda_api",
31-
"zluda_blas",
32-
"zluda_blaslt",
33-
"zluda_ccl",
34-
"zluda_dark_api",
35-
"zluda_dnn",
36-
"zluda_dump",
37-
"zluda_fft",
38-
"zluda_inject",
39-
"zluda_lib",
40-
"zluda_llvm",
41-
"zluda_ml",
42-
"zluda_redirect",
43-
"zluda_rt",
44-
"zluda_rtc",
45-
"zluda_runtime",
46-
"zluda_sparse",
47-
]
48-
49-
# Cargo does not support OS-specific or profile-specific
50-
# targets. We keep list here to bare minimum and rely on xtask
51-
default-members = [
52-
"zluda_lib",
53-
"zluda_ml",
54-
"zluda_inject",
55-
"zluda_redirect"
56-
]
57-
58-
[profile.dev.package.blake3]
59-
opt-level = 3
60-
61-
[profile.dev.package.lz4-sys]
62-
opt-level = 3
63-
64-
[profile.dev.package.xtask]
65-
opt-level = 2
1+
[workspace]
2+
3+
resolver = "2"
4+
5+
# Remember to also update the project's Cargo.toml
6+
# if it's a top-level project
7+
members = [
8+
"atiadlxx-sys",
9+
"comgr",
10+
"cuda_base",
11+
"cuda_types",
12+
"detours-sys",
13+
"ext/llvm-sys.rs",
14+
"hip_common",
15+
"hip_runtime-sys",
16+
"hipblaslt-sys",
17+
"hipfft-sys",
18+
"hiprt-sys",
19+
"miopen-sys",
20+
"offline_compiler",
21+
"optix_base",
22+
"optix_dump",
23+
"process_address_table",
24+
"ptx",
25+
"rocblas-sys",
26+
"rocm_smi-sys",
27+
"rocsparse-sys",
28+
"xtask",
29+
"zluda",
30+
"zluda_api",
31+
"zluda_blas",
32+
"zluda_blaslt",
33+
"zluda_ccl",
34+
"zluda_dark_api",
35+
"zluda_dnn",
36+
"zluda_dump",
37+
"zluda_fft",
38+
"zluda_fftw",
39+
"zluda_inject",
40+
"zluda_lib",
41+
"zluda_llvm",
42+
"zluda_ml",
43+
"zluda_redirect",
44+
"zluda_rt",
45+
"zluda_rtc",
46+
"zluda_runtime",
47+
"zluda_sparse",
48+
]
49+
50+
# Cargo does not support OS-specific or profile-specific
51+
# targets. We keep the list here to a bare minimum and rely on xtask
52+
default-members = [
53+
"zluda_lib",
54+
"zluda_ml",
55+
"zluda_inject",
56+
"zluda_redirect"
57+
]
58+
59+
[profile.dev.package.blake3]
60+
opt-level = 3
61+
62+
[profile.dev.package.lz4-sys]
63+
opt-level = 3
64+
65+
[profile.dev.package.xtask]
66+
opt-level = 2

hipblaslt-sys/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ impl hipblasOperation_t {
1111
impl hipblasOperation_t {
1212
pub const HIPBLAS_OP_C: hipblasOperation_t = hipblasOperation_t(113);
1313
}
14+
#[allow(non_camel_case_types)]
1415
#[repr(transparent)]
1516
#[derive(Copy, Clone, Hash, PartialEq, Eq)]
1617
pub struct hipblasOperation_t(pub ::std::os::raw::c_int);

zluda_blas/src/lib.rs

Lines changed: 12 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
#![allow(warnings)]
1+
#[allow(warnings)]
22
mod common;
3+
#[allow(warnings)]
34
mod cublas;
5+
#[allow(warnings)]
46
mod cublasxt;
57

68
pub use common::*;
@@ -13,7 +15,7 @@ use rocsolver_sys::{
1315
rocsolver_cgetrf_batched, rocsolver_cgetri_outofplace_batched, rocsolver_dgetrs_batched,
1416
rocsolver_sgetrs_batched, rocsolver_zgetrf_batched, rocsolver_zgetri_outofplace_batched,
1517
};
16-
use std::{mem, ptr};
18+
use std::ptr;
1719

1820
#[cfg(debug_assertions)]
1921
pub(crate) fn unsupported() -> cublasStatus_t {
@@ -223,61 +225,20 @@ unsafe fn set_stream(handle: cublasHandle_t, stream_id: cudaStream_t) -> cublasS
223225
) -> CUresult>(b"cuGetExportTable\0")
224226
.unwrap();
225227
let mut export_table = ptr::null();
226-
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID);
228+
assert_eq!(
229+
(cu_get_export_table)(&mut export_table, &zluda_dark_api::ZludaExt::GUID),
230+
CUresult::CUDA_SUCCESS
231+
);
227232
let zluda_ext = zluda_dark_api::ZludaExt::new(export_table);
228233
let stream: Result<_, _> = zluda_ext.get_hip_stream(stream_id as _).into();
229234
to_cuda(rocblas_set_stream(handle as _, stream.unwrap() as _))
230235
}
231236

232-
fn set_math_mode(handle: cublasHandle_t, mode: cublasMath_t) -> cublasStatus_t {
237+
fn set_math_mode(_handle: cublasHandle_t, _mode: cublasMath_t) -> cublasStatus_t {
233238
// llama.cpp uses CUBLAS_TF32_TENSOR_OP_MATH
234239
cublasStatus_t::CUBLAS_STATUS_SUCCESS
235240
}
236241

237-
unsafe fn sgemm(
238-
transa: std::ffi::c_char,
239-
transb: std::ffi::c_char,
240-
m: i32,
241-
n: i32,
242-
k: i32,
243-
alpha: f32,
244-
a: *const f32,
245-
lda: i32,
246-
b: *const f32,
247-
ldb: i32,
248-
beta: f32,
249-
c: *mut f32,
250-
ldc: i32,
251-
) -> cublasStatus_t {
252-
let mut handle = mem::zeroed();
253-
let mut status = to_cuda(rocblas_create_handle(handle));
254-
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
255-
return status;
256-
}
257-
let transa = op_from_cuda(cublasOperation_t(transa as _));
258-
let transb = op_from_cuda(cublasOperation_t(transb as _));
259-
status = to_cuda(rocblas_sgemm(
260-
handle.cast(),
261-
transa,
262-
transb,
263-
m,
264-
n,
265-
k,
266-
&alpha,
267-
a,
268-
lda,
269-
b,
270-
ldb,
271-
&beta,
272-
c,
273-
ldc,
274-
));
275-
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
276-
return status;
277-
}
278-
to_cuda(rocblas_destroy_handle(*handle))
279-
}
280-
281242
unsafe fn sgemm_v2(
282243
handle: cublasHandle_t,
283244
transa: cublasOperation_t,
@@ -495,7 +456,7 @@ unsafe fn gemm_ex(
495456
))
496457
}
497458

498-
fn to_algo(algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
459+
fn to_algo(_algo: cublasGemmAlgo_t) -> rocblas_gemm_algo_ {
499460
// only option
500461
rocblas_gemm_algo::rocblas_gemm_algo_standard
501462
}
@@ -807,7 +768,7 @@ unsafe fn sgetrs_batched(
807768
dev_ipiv: *const i32,
808769
b: *const *mut f32,
809770
ldb: i32,
810-
info: *mut i32,
771+
_info: *mut i32,
811772
batch_size: i32,
812773
) -> cublasStatus_t {
813774
let trans = op_from_cuda_for_solver(trans);
@@ -837,7 +798,7 @@ unsafe fn dgetrs_batched(
837798
dev_ipiv: *const i32,
838799
b: *const *mut f64,
839800
ldb: i32,
840-
info: *mut i32,
801+
_info: *mut i32,
841802
batch_size: i32,
842803
) -> cublasStatus_t {
843804
let trans = op_from_cuda_for_solver(trans);
@@ -1048,50 +1009,6 @@ unsafe fn dger(
10481009
))
10491010
}
10501011

1051-
unsafe fn dgemm(
1052-
transa: std::ffi::c_char,
1053-
transb: std::ffi::c_char,
1054-
m: i32,
1055-
n: i32,
1056-
k: i32,
1057-
alpha: f64,
1058-
a: *const f64,
1059-
lda: i32,
1060-
b: *const f64,
1061-
ldb: i32,
1062-
beta: f64,
1063-
c: *mut f64,
1064-
ldc: i32,
1065-
) -> cublasStatus_t {
1066-
let mut handle = mem::zeroed();
1067-
let mut status = to_cuda(rocblas_create_handle(handle));
1068-
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1069-
return status;
1070-
}
1071-
let transa = op_from_cuda(cublasOperation_t(transa as _));
1072-
let transb = op_from_cuda(cublasOperation_t(transb as _));
1073-
status = to_cuda(rocblas_dgemm(
1074-
handle.cast(),
1075-
transa,
1076-
transb,
1077-
m,
1078-
n,
1079-
k,
1080-
&alpha,
1081-
a,
1082-
lda,
1083-
b,
1084-
ldb,
1085-
&beta,
1086-
c,
1087-
ldc,
1088-
));
1089-
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
1090-
return status;
1091-
}
1092-
to_cuda(rocblas_destroy_handle(*handle))
1093-
}
1094-
10951012
unsafe fn dgemm_v2(
10961013
handle: *mut cublasContext,
10971014
transa: cublasOperation_t,

zluda_fft/src/cufft.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,15 @@ pub unsafe extern "system" fn cufftSetWorkArea(
380380
plan: cufftHandle,
381381
workArea: *mut ::std::os::raw::c_void,
382382
) -> cufftResult {
383-
crate::unsupported()
383+
crate::set_work_area(plan, workArea)
384384
}
385385

386386
#[no_mangle]
387387
pub unsafe extern "system" fn cufftSetAutoAllocation(
388388
plan: cufftHandle,
389389
autoAllocate: ::std::os::raw::c_int,
390390
) -> cufftResult {
391-
crate::unsupported()
391+
crate::set_auto_allocation(plan, autoAllocate)
392392
}
393393

394394
#[no_mangle]

zluda_fft/src/cufftxt.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,22 @@ pub unsafe extern "system" fn cufftXtMakePlanMany(
376376
workSize: *mut usize,
377377
executiontype: cudaDataType,
378378
) -> cufftResult {
379-
crate::unsupported()
379+
crate::xt_make_plan_many(
380+
plan,
381+
rank,
382+
n,
383+
inembed,
384+
istride,
385+
idist,
386+
inputtype,
387+
onembed,
388+
ostride,
389+
odist,
390+
outputtype,
391+
batch,
392+
workSize,
393+
executiontype,
394+
)
380395
}
381396

382397
#[no_mangle]
@@ -406,7 +421,7 @@ pub unsafe extern "system" fn cufftXtExec(
406421
output: *mut ::std::os::raw::c_void,
407422
direction: ::std::os::raw::c_int,
408423
) -> cufftResult {
409-
crate::unsupported()
424+
crate::xt_exec(plan, input, output, direction)
410425
}
411426

412427
#[no_mangle]

0 commit comments

Comments
 (0)