Skip to content

Commit 463fb38

Browse files
committed
fixed cuda reformatting, updated hooks
1 parent 9d816fe commit 463fb38

File tree

9 files changed

+52
-107
lines changed

9 files changed

+52
-107
lines changed

.clang-format

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ AllowShortLoopsOnASingleLine: true
1414
BreakBeforeBraces: Attach
1515
BreakBeforeBinaryOperators: None
1616
ColumnLimit: 90
17-
ExperimentalAutoDetectBinPacking: true
17+
ExperimentalAutoDetectBinPacking: false # <--- safer default for .cu/.cuh
1818
FixNamespaceComments: true
1919
IndentWidth: 2
2020
MaxEmptyLinesToKeep: 1
2121
NamespaceIndentation: None
22-
ReflowComments: true
22+
ReflowComments: IndentOnly
2323
PenaltyBreakComment: 1
24-
PenaltyBreakOpenParenthesis: 1 # modified; was 0
24+
PenaltyBreakOpenParenthesis: 1
2525
SortIncludes: CaseSensitive
2626
SortUsingDeclarations: true
2727
SpacesBeforeTrailingComments: 1

.pre-commit-config.yaml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
repos:
22
- repo: https://github.com/pre-commit/mirrors-clang-format
3-
rev: 'v19.1.7'
3+
rev: 'v20.1.8'
44
hooks:
55
- id: clang-format
66
types_or: [c++, c, cuda]
7+
files: \.(c|cc|h|hpp|cu|cuh)$
78
exclude: '(^|/)(matlab/.*)$'
89
- repo: https://github.com/pre-commit/pre-commit-hooks
910
rev: v5.0.0
@@ -14,13 +15,19 @@ repos:
1415
- id: check-illegal-windows-names
1516
- id: mixed-line-ending
1617
- repo: https://github.com/BlankSpruce/gersemi
17-
rev: 0.19.1
18+
rev: 0.21.0
1819
hooks:
1920
- id: gersemi
2021
- repo: https://github.com/abravalheri/validate-pyproject
21-
rev: v0.23 # Use the latest stable version
22+
rev: v0.24.1
2223
hooks:
2324
- id: validate-pyproject
2425
# Optional: Include additional validations from SchemaStore
2526
additional_dependencies: ["validate-pyproject-schema-store[all]"]
2627
files: ^python/(finufft|cufinufft)/pyproject\.toml$
28+
- repo: https://github.com/LilSpazJoekp/docstrfmt
29+
rev: v1.10.0
30+
hooks:
31+
- id: docstrfmt
32+
language_version: python3
33+
types_or: [rst] # only needed if you want to include txt files.

include/cufinufft/impl.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -707,8 +707,7 @@ int cufinufft_setpts_impl(int M, T *d_kx, T *d_ky, T *d_kz, int N, T *d_s, T *d_
707707
d_plan->spopts);
708708
if (d_plan->dim > 1) {
709709
onedim_nuft_kernel_precomp<T>(nuft_precomp_f.data() + MAX_NQUAD,
710-
nuft_precomp_z.data() + MAX_NQUAD,
711-
d_plan->spopts);
710+
nuft_precomp_z.data() + MAX_NQUAD, d_plan->spopts);
712711
}
713712
if (d_plan->dim > 2) {
714713
onedim_nuft_kernel_precomp<T>(nuft_precomp_f.data() + 2 * MAX_NQUAD,
@@ -765,8 +764,8 @@ int cufinufft_setpts_impl(int M, T *d_kx, T *d_ky, T *d_kz, int N, T *d_s, T *d_
765764
thrust::cuda::par.on(stream), phase_iterator, phase_iterator + N,
766765
d_plan->deconv, d_plan->deconv,
767766
[c1, c2, c3, d1, d2, d3, realsign] __host__
768-
__device__(const thrust::tuple<T, T, T> tuple, cuda_complex<T> deconv)
769-
-> cuda_complex<T> {
767+
__device__(const thrust::tuple<T, T, T> tuple,
768+
cuda_complex<T> deconv) -> cuda_complex<T> {
770769
// d2 and d3 are 0 if dim < 2 and dim < 3
771770
const auto phase = c1 * (thrust::get<0>(tuple) + d1) +
772771
c2 * (thrust::get<1>(tuple) + d2) +

src/cuda/memtransfer_wrapper.cu

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@ int allocgpumem1d_plan(cufinufft_plan_t<T> *d_plan)
5252
cudaMallocWrapper(&d_plan->binstartpts, numbins * sizeof(int), stream,
5353
d_plan->supports_pools))))
5454
goto finalize;
55-
if ((ier = checkCudaErrors(cudaMallocWrapper(&d_plan->subprobstartpts,
56-
(numbins + 1) * sizeof(int),
57-
stream,
58-
d_plan->supports_pools))))
55+
if ((ier = checkCudaErrors(
56+
cudaMallocWrapper(&d_plan->subprobstartpts, (numbins + 1) * sizeof(int),
57+
stream, d_plan->supports_pools))))
5958
goto finalize;
6059
} break;
6160
default:
@@ -65,10 +64,8 @@ int allocgpumem1d_plan(cufinufft_plan_t<T> *d_plan)
6564

6665
if (!d_plan->opts.gpu_spreadinterponly) {
6766
if ((ier = checkCudaErrors(
68-
cudaMallocWrapper(&d_plan->fw,
69-
maxbatchsize * nf1 * sizeof(cuda_complex<T>),
70-
stream,
71-
d_plan->supports_pools))))
67+
cudaMallocWrapper(&d_plan->fw, maxbatchsize * nf1 * sizeof(cuda_complex<T>),
68+
stream, d_plan->supports_pools))))
7269
goto finalize;
7370
if ((ier = checkCudaErrors(
7471
cudaMallocWrapper(&d_plan->fwkerhalf1, (nf1 / 2 + 1) * sizeof(T), stream,

test/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@ function(add_tests_with_prec PREC REQ_TOL CHECK_TOL SUFFIX)
9595
COMMAND spreadinterp1d_test${SUFFIX} 1e3 1e3 ${REQ_TOL} 0 2 2.0 ${CHECK_TOL}
9696
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
9797
)
98-
99-
add_test(NAME run_adjointness_${PREC} COMMAND adjointness${SUFFIX} WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
10098

99+
add_test(NAME run_adjointness_${PREC} COMMAND adjointness${SUFFIX} WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
101100
endfunction()
102101

103102
# use above function to actually add the tests, with certain requested and check

test/cuda/cufinufft3d_test.cu

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,20 +256,19 @@ int run_test(int method, int type, int N1, int N2, int N3, int M, T tol, T check
256256

257257
int main(int argc, char *argv[]) {
258258
if (argc != 11) {
259-
fprintf(stderr,
260-
"Usage: cufinufft3d1_test method type N1 N2 N3 M tol checktol prec\n"
261-
"Arguments:\n"
262-
" method: One of\n"
263-
" 1: nupts driven,\n"
264-
" 2: sub-problem, or\n"
265-
" 4: block gather.\n"
266-
" type: Type of transform (1, 2, 3)"
267-
" N1, N2, N3: The size of the 3D array\n"
268-
" M: The number of non-uniform points\n"
269-
" tol: NUFFT tolerance\n"
270-
" checktol: relative error to pass test\n"
271-
" prec: 'f' or 'd' (float/double)\n"
272-
" upsamplefac: upsampling factor\n");
259+
fprintf(stderr, "Usage: cufinufft3d1_test method type N1 N2 N3 M tol checktol prec\n"
260+
"Arguments:\n"
261+
" method: One of\n"
262+
" 1: nupts driven,\n"
263+
" 2: sub-problem, or\n"
264+
" 4: block gather.\n"
265+
" type: Type of transform (1, 2, 3)"
266+
" N1, N2, N3: The size of the 3D array\n"
267+
" M: The number of non-uniform points\n"
268+
" tol: NUFFT tolerance\n"
269+
" checktol: relative error to pass test\n"
270+
" prec: 'f' or 'd' (float/double)\n"
271+
" upsamplefac: upsampling factor\n");
273272
return 1;
274273
}
275274
const int method = atoi(argv[1]);

test/utils/dirft1d.hpp

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@ template<typename BIGINT,
2020
typename CArr, // cArr[j] yields Complex<FLT>
2121
typename FArr // fArr[m] yields Complex<FLT>
2222
>
23-
void dirft1d1(const BIGINT nj,
24-
const XArr &x,
25-
const CArr &c,
26-
const int iflag,
27-
const BIGINT ms,
28-
FArr &fArr) {
23+
void dirft1d1(const BIGINT nj, const XArr &x, const CArr &c, const int iflag,
24+
const BIGINT ms, FArr &fArr) {
2925
using Complex = std::decay_t<decltype(c[0])>;
3026
using FLT = typename Complex::value_type;
3127

@@ -63,11 +59,7 @@ template<typename BIGINT,
6359
typename CArr, // cArr[j] yields Complex<FLT>
6460
typename FArr // fArr[m] yields Complex<FLT>
6561
>
66-
void dirft1d2(const BIGINT nj,
67-
const XArr &x,
68-
CArr &c,
69-
const int iflag,
70-
const BIGINT ms,
62+
void dirft1d2(const BIGINT nj, const XArr &x, CArr &c, const int iflag, const BIGINT ms,
7163
const FArr &f) {
7264
using Complex = std::decay_t<decltype(c[0])>;
7365
using FLT = typename Complex::value_type;
@@ -107,13 +99,8 @@ template<typename BIGINT,
10799
typename SArr, // sArr[k] yields real FLT
108100
typename FArr // fArr[k] yields Complex<FLT>
109101
>
110-
void dirft1d3(const BIGINT nj,
111-
const XArr &x,
112-
const CArr &c,
113-
const int iflag,
114-
const BIGINT nk,
115-
const SArr &sArr,
116-
FArr &f) {
102+
void dirft1d3(const BIGINT nj, const XArr &x, const CArr &c, const int iflag,
103+
const BIGINT nk, const SArr &sArr, FArr &f) {
117104
using Complex = std::decay_t<decltype(c[0])>;
118105
using FLT = typename Complex::value_type;
119106

test/utils/dirft2d.hpp

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,8 @@ template<typename BIGINT,
2323
typename CArr, // c[j] → Complex<FLT>
2424
typename FArr // f[m] → Complex<FLT>
2525
>
26-
void dirft2d1(BIGINT nj,
27-
const XYArr &x,
28-
const XYArr &y,
29-
const CArr &c,
30-
int iflag,
31-
BIGINT ms,
32-
BIGINT mt,
33-
FArr &f) {
26+
void dirft2d1(BIGINT nj, const XYArr &x, const XYArr &y, const CArr &c, int iflag,
27+
BIGINT ms, BIGINT mt, FArr &f) {
3428
using Complex = std::decay_t<decltype(c[0])>;
3529
using FLT = typename Complex::value_type;
3630

@@ -71,14 +65,8 @@ void dirft2d1(BIGINT nj,
7165
// exponential.
7266
// Uses winding trick. Barnett 1/26/17
7367
template<typename BIGINT, typename XYArr, typename CArr, typename FArr>
74-
void dirft2d2(BIGINT nj,
75-
const XYArr &x,
76-
const XYArr &y,
77-
CArr &c,
78-
int iflag,
79-
BIGINT ms,
80-
BIGINT mt,
81-
const FArr &f) {
68+
void dirft2d2(BIGINT nj, const XYArr &x, const XYArr &y, CArr &c, int iflag, BIGINT ms,
69+
BIGINT mt, const FArr &f) {
8270
using Complex = std::decay_t<decltype(c[0])>;
8371
using FLT = typename Complex::value_type;
8472

@@ -117,15 +105,8 @@ void dirft2d2(BIGINT nj,
117105
// If iflag>0 the + sign is used, otherwise the - sign is used, in the
118106
// exponential. Simple brute force. Barnett 1/26/17
119107
template<typename BIGINT, typename XYArr, typename CArr, typename STArr, typename FArr>
120-
void dirft2d3(BIGINT nj,
121-
const XYArr &x,
122-
const XYArr &y,
123-
const CArr &c,
124-
int iflag,
125-
BIGINT nk,
126-
const STArr &s,
127-
const STArr &t,
128-
FArr &f) {
108+
void dirft2d3(BIGINT nj, const XYArr &x, const XYArr &y, const CArr &c, int iflag,
109+
BIGINT nk, const STArr &s, const STArr &t, FArr &f) {
129110
using Complex = std::decay_t<decltype(c[0])>;
130111
using FLT = typename Complex::value_type;
131112

test/utils/dirft3d.hpp

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,8 @@ template<typename BIGINT,
2323
typename XYZArr, // x[j], y[j], z[j] → FLT
2424
typename CArr, // c[j] → Complex<FLT>
2525
typename FArr> // f[m] → Complex<FLT>
26-
void dirft3d1(BIGINT nj,
27-
const XYZArr &x,
28-
const XYZArr &y,
29-
const XYZArr &z,
30-
const CArr &c,
31-
int iflag,
32-
BIGINT ms,
33-
BIGINT mt,
34-
BIGINT mu,
35-
FArr &f) {
26+
void dirft3d1(BIGINT nj, const XYZArr &x, const XYZArr &y, const XYZArr &z, const CArr &c,
27+
int iflag, BIGINT ms, BIGINT mt, BIGINT mu, FArr &f) {
3628
using Complex = std::decay_t<decltype(c[0])>;
3729
using FLT = typename Complex::value_type;
3830

@@ -85,16 +77,8 @@ void dirft3d1(BIGINT nj,
8577
// Uses winding trick. Barnett 2/1/17
8678
// ------------------------------------------------------------
8779
template<typename BIGINT, typename XYZArr, typename CArr, typename FArr>
88-
void dirft3d2(BIGINT nj,
89-
const XYZArr &x,
90-
const XYZArr &y,
91-
const XYZArr &z,
92-
CArr &c,
93-
int iflag,
94-
BIGINT ms,
95-
BIGINT mt,
96-
BIGINT mu,
97-
const FArr &f) {
80+
void dirft3d2(BIGINT nj, const XYZArr &x, const XYZArr &y, const XYZArr &z, CArr &c,
81+
int iflag, BIGINT ms, BIGINT mt, BIGINT mu, const FArr &f) {
9882
using Complex = std::decay_t<decltype(c[0])>;
9983
using FLT = typename Complex::value_type;
10084

@@ -137,16 +121,8 @@ void dirft3d2(BIGINT nj,
137121
// f[k] = Σ_j c[j] exp(i * iflag * (s[k] x[j] + t[k] y[j] + u[k] z[j]))
138122
// ------------------------------------------------------------
139123
template<typename BIGINT, typename XYZArr, typename CArr, typename STUArr, typename FArr>
140-
void dirft3d3(BIGINT nj,
141-
const XYZArr &x,
142-
const XYZArr &y,
143-
const XYZArr &z,
144-
const CArr &c,
145-
int iflag,
146-
BIGINT nk,
147-
const STUArr &s,
148-
const STUArr &t,
149-
const STUArr &u,
124+
void dirft3d3(BIGINT nj, const XYZArr &x, const XYZArr &y, const XYZArr &z, const CArr &c,
125+
int iflag, BIGINT nk, const STUArr &s, const STUArr &t, const STUArr &u,
150126
FArr &f) {
151127
using Complex = std::decay_t<decltype(c[0])>;
152128
using FLT = typename Complex::value_type;

0 commit comments

Comments
 (0)