Skip to content

Commit 9253132

Browse files
committed
Add oneMKL DFT
1 parent a0ee705 commit 9253132

File tree

9 files changed

+1327
-5
lines changed

9 files changed

+1327
-5
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ authors = ["Tim Besard <tim.besard@gmail.com>"]
44
version = "2.0.3"
55

66
[deps]
7+
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
910
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
11+
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1012
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1113
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
1214
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
@@ -29,6 +31,7 @@ oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
2931
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
3032

3133
[compat]
34+
AbstractFFTs = "1.5.0"
3235
Adapt = "4"
3336
CEnum = "0.4, 0.5"
3437
ExprTools = "0.1"

deps/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,21 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
66

77
project(oneAPISupport)
88

9-
add_library(oneapi_support SHARED src/sycl.h src/sycl.hpp src/sycl.cpp src/onemkl.h src/onemkl.cpp)
9+
add_library(oneapi_support SHARED
10+
src/sycl.h
11+
src/sycl.hpp
12+
src/sycl.cpp
13+
src/onemkl.h
14+
src/onemkl.cpp
15+
src/onemkl_dft.h
16+
src/onemkl_dft.cpp
17+
)
1018

1119
target_link_libraries(oneapi_support
1220
mkl_sycl
21+
# DFT component libraries needed for oneMKL DFT template instantiations
22+
mkl_sycl_dft
23+
mkl_cdft_core
1324
mkl_intel_ilp64
1425
mkl_sequential
1526
mkl_core

deps/src/onemkl_dft.cpp

Lines changed: 466 additions & 0 deletions
Large diffs are not rendered by default.

deps/src/onemkl_dft.h

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#pragma once
2+
3+
#include "sycl.h"
4+
5+
#include <stdint.h>
6+
7+
#ifdef __cplusplus
8+
extern "C" {
9+
#endif
10+
11+
// Return codes (negative values indicate errors):
12+
// 0 : success
13+
// -1 : internal error / exception caught
14+
// -2 : invalid argument (null pointer, bad length, etc.)
15+
// -3 : invalid descriptor state (e.g. uninitialized desc->ptr) or size query failure
16+
#define ONEMKL_DFT_STATUS_SUCCESS 0
17+
#define ONEMKL_DFT_STATUS_ERROR -1
18+
#define ONEMKL_DFT_STATUS_INVALID_ARGUMENT -2
19+
#define ONEMKL_DFT_STATUS_BAD_STATE -3
20+
21+
// DFT precision
22+
typedef enum {
23+
ONEMKL_DFT_PRECISION_SINGLE = 0,
24+
ONEMKL_DFT_PRECISION_DOUBLE = 1
25+
} onemklDftPrecision;
26+
27+
// DFT domain
28+
typedef enum {
29+
ONEMKL_DFT_DOMAIN_REAL = 0,
30+
ONEMKL_DFT_DOMAIN_COMPLEX = 1
31+
} onemklDftDomain;
32+
33+
// Configuration parameters (subset mirrors oneapi::mkl::dft::config_param)
34+
typedef enum {
35+
ONEMKL_DFT_PARAM_FORWARD_DOMAIN = 0,
36+
ONEMKL_DFT_PARAM_DIMENSION,
37+
ONEMKL_DFT_PARAM_LENGTHS,
38+
ONEMKL_DFT_PARAM_PRECISION,
39+
ONEMKL_DFT_PARAM_FORWARD_SCALE,
40+
ONEMKL_DFT_PARAM_BACKWARD_SCALE,
41+
ONEMKL_DFT_PARAM_NUMBER_OF_TRANSFORMS,
42+
ONEMKL_DFT_PARAM_COMPLEX_STORAGE,
43+
ONEMKL_DFT_PARAM_PLACEMENT,
44+
ONEMKL_DFT_PARAM_INPUT_STRIDES,
45+
ONEMKL_DFT_PARAM_OUTPUT_STRIDES,
46+
ONEMKL_DFT_PARAM_FWD_DISTANCE,
47+
ONEMKL_DFT_PARAM_BWD_DISTANCE,
48+
ONEMKL_DFT_PARAM_WORKSPACE, // size query / placement
49+
ONEMKL_DFT_PARAM_WORKSPACE_ESTIMATE_BYTES,
50+
ONEMKL_DFT_PARAM_WORKSPACE_BYTES,
51+
ONEMKL_DFT_PARAM_FWD_STRIDES,
52+
ONEMKL_DFT_PARAM_BWD_STRIDES,
53+
ONEMKL_DFT_PARAM_WORKSPACE_PLACEMENT,
54+
ONEMKL_DFT_PARAM_WORKSPACE_EXTERNAL_BYTES
55+
} onemklDftConfigParam;
56+
57+
// Configuration values (mirrors oneapi::mkl::dft::config_value)
58+
typedef enum {
59+
ONEMKL_DFT_VALUE_COMMITTED = 0,
60+
ONEMKL_DFT_VALUE_UNCOMMITTED,
61+
ONEMKL_DFT_VALUE_COMPLEX_COMPLEX,
62+
ONEMKL_DFT_VALUE_REAL_REAL,
63+
ONEMKL_DFT_VALUE_INPLACE,
64+
ONEMKL_DFT_VALUE_NOT_INPLACE,
65+
ONEMKL_DFT_VALUE_WORKSPACE_AUTOMATIC, // internal
66+
ONEMKL_DFT_VALUE_ALLOW,
67+
ONEMKL_DFT_VALUE_AVOID,
68+
ONEMKL_DFT_VALUE_WORKSPACE_INTERNAL,
69+
ONEMKL_DFT_VALUE_WORKSPACE_EXTERNAL
70+
} onemklDftConfigValue;
71+
72+
// Opaque descriptor handle
73+
struct onemklDftDescriptor_st;
74+
typedef struct onemklDftDescriptor_st *onemklDftDescriptor_t;
75+
76+
// Creation / destruction
77+
int onemklDftCreate1D(onemklDftDescriptor_t *desc,
78+
onemklDftPrecision precision,
79+
onemklDftDomain domain,
80+
int64_t length);
81+
82+
int onemklDftCreateND(onemklDftDescriptor_t *desc,
83+
onemklDftPrecision precision,
84+
onemklDftDomain domain,
85+
int64_t dim,
86+
const int64_t *lengths);
87+
88+
int onemklDftDestroy(onemklDftDescriptor_t desc);
89+
90+
// Commit descriptor to a queue
91+
int onemklDftCommit(onemklDftDescriptor_t desc, syclQueue_t queue);
92+
93+
// Configuration set
94+
int onemklDftSetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t value);
95+
int onemklDftSetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double value);
96+
int onemklDftSetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, const int64_t *values, int64_t n);
97+
int onemklDftSetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue value);
98+
99+
// Configuration get
100+
int onemklDftGetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *value);
101+
int onemklDftGetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double *value);
102+
// For array queries pass *n as available length; on return *n has elements written.
103+
int onemklDftGetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *values, int64_t *n);
104+
int onemklDftGetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue *value);
105+
106+
// Compute (USM) in-place/out-of-place. Pointers must reference memory
107+
// appropriate for precision/domain. No size checking is performed.
108+
int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout);
109+
int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out);
110+
int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout);
111+
int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out);
112+
113+
// Compute (buffer API) variants. Host pointers are wrapped in temporary 1D buffers.
114+
int onemklDftComputeForwardBuffer(onemklDftDescriptor_t desc, void *inout);
115+
int onemklDftComputeForwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out);
116+
int onemklDftComputeBackwardBuffer(onemklDftDescriptor_t desc, void *inout);
117+
int onemklDftComputeBackwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out);
118+
119+
// Introspection: write out the integral values of selected config_param enums in
120+
// the same order as our public enum declaration above. Returns number written or
121+
// a negative error code if n is insufficient or arguments invalid.
122+
int onemklDftQueryParamIndices(int64_t *out, int64_t n);
123+
124+
#ifdef __cplusplus
125+
}
126+
#endif

0 commit comments

Comments
 (0)