Skip to content

Commit 0227db1

Browse files
committed
Extract bracket calculation (correct)
1 parent 6c77a1d commit 0227db1

File tree

5 files changed

+145
-96
lines changed

5 files changed

+145
-96
lines changed

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib
6262
lib/Transformer.hpp lib/Transformer.cpp
6363
lib/Exporter.hpp lib/Exporter.cpp
6464
lib/Filter.hpp lib/Filter.cpp
65+
lib/Brackets.hpp lib/Brackets.cpp
6566
)
6667
target_include_directories(src-lib PUBLIC lib/)
6768
target_link_libraries(src-lib mdspan fftw-cpp cnpy spdlog::spdlog)

cpp/lib/Brackets.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "Brackets.hpp"
2+
#include "Transformer.hpp"
3+
#include "Filter.hpp"
4+
5+
namespace ahr {
6+
7+
void Brackets::prepareDXY_PH(View::C_XY const &view_K, View::C_XY const &viewDX_K,
8+
View::C_XY const &viewDY_K) const {
9+
grid.for_each_kxky([&](Dim kx, Dim ky) {
10+
viewDX_K(kx, ky) = kx_(kx) * 1i * view_K(kx, ky) * XYNorm;
11+
viewDY_K(kx, ky) = ky_(ky) * 1i * view_K(kx, ky) * XYNorm;
12+
});
13+
}
14+
15+
void Brackets::bracket(DxDy<View::R_XY> const &op1, DxDy<View::R_XY> const &op2,
16+
View::R_XY const &output) const {
17+
grid.for_each_xy([&](Dim x, Dim y) {
18+
output(x, y) = op1.DX(x, y) * op2.DY(x, y) - op1.DY(x, y) * op2.DX(x, y);
19+
});
20+
}
21+
22+
void Brackets::derivatives(View::C_XY const &op, DxDy<View::R_XY> output) const {
23+
DxDy<Buf::C_XY> Der_K{grid.KX, grid.KY};
24+
prepareDXY_PH(op, Der_K.DX, Der_K.DY);
25+
tf.bfft(Der_K.DX, output.DX);
26+
tf.bfft(Der_K.DY, output.DY);
27+
}
28+
29+
Brackets::Buf::C_XY Brackets::halfBracket(DxDy<View::R_XY> derOp1, DxDy<View::R_XY> derOp2) const {
30+
Buf::R_XY br = grid.rBufXY();
31+
Buf::C_XY br_K = grid.cBufXY();
32+
bracket(derOp1, derOp2, br);
33+
tf.fft(br, br_K);
34+
hlFilter(br_K);
35+
br_K(0, 0) = 0;
36+
return br_K;
37+
}
38+
39+
[[nodiscard]] Brackets::Buf::C_XY Brackets::fullBracket(View::C_XY op1, View::C_XY op2) const {
40+
auto derOp1 = grid.dBufXY(), derOp2 = grid.dBufXY();
41+
derivatives(op1, derOp1);
42+
derivatives(op2, derOp2);
43+
44+
return halfBracket(derOp1, derOp2);
45+
}
46+
47+
} // namespace ahr

cpp/lib/Brackets.hpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#pragma once
2+
#include "constants.hpp"
3+
#include "grid.hpp"
4+
5+
namespace ahr {
6+
7+
class Transformer;
8+
class HouLiFilter;
9+
10+
class Brackets {
11+
public:
12+
Brackets(Grid const &grid, Transformer const &tf, HouLiFilter const &hlFilter)
13+
: grid(grid), tf(tf), hlFilter(hlFilter) {}
14+
15+
private:
16+
Grid const &grid;
17+
Transformer const &tf;
18+
HouLiFilter const &hlFilter; // TODO vectorized
19+
20+
using View = Grid::View;
21+
using Buf = Grid::Buf;
22+
template <class T> using DxDy = Grid::DxDy<T>;
23+
24+
public:
25+
/// Compute real δx and δy derivatives of complex op, store in output
26+
void derivatives(View::C_XY const &op, DxDy<View::R_XY> output) const;
27+
28+
/// Compute the bracket of two complex fields using their derivatives
29+
[[nodiscard]] Buf::C_XY halfBracket(DxDy<View::R_XY> op1, DxDy<View::R_XY> op2) const;
30+
31+
/// Compute the bracket of two complex fields using their values
32+
[[nodiscard]] Buf::C_XY fullBracket(View::C_XY op1, View::C_XY op2) const;
33+
34+
private:
35+
/// Prepares the δx and δy of viewPH in phase space, as well as over-normalizes
36+
/// (after inverse FFT, values will be properly normalized)
37+
void prepareDXY_PH(View::C_XY const &view_K, View::C_XY const &viewDX_K,
38+
View::C_XY const &viewDY_K) const;
39+
40+
/// Computes bracket [op1, op2], expects normalized values
41+
void bracket(DxDy<View::R_XY> const &op1, DxDy<View::R_XY> const &op2,
42+
View::R_XY const &output) const;
43+
44+
/// Normalization factor for FFT
45+
const Real XYNorm{1.0 / Real(grid.X) / Real(grid.Y)};
46+
47+
// TODO extract to common utility
48+
[[nodiscard]] Real ky_(Dim ky) const {
49+
return (ky <= (grid.KY / 2) ? Real(ky) : Real(ky) - Real(grid.KY)) * Real(lx) / Real(ly);
50+
}
51+
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
52+
};
53+
} // namespace ahr

cpp/lib/Naive.cpp

Lines changed: 41 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,11 @@ void Naive::init(std::string_view equilibriumName) {
4949
moments_K(kx, ky, A_PAR) = aParEq_K(kx, ky);
5050
ueKPar_K(kx, ky) = -kPerp2(kx, ky) * moments_K(kx, ky, A_PAR);
5151
});
52-
derivatives(phi_K, dPhi);
53-
derivatives(ueKPar_K, dUEKPar);
52+
br.derivatives(phi_K, dPhi);
53+
br.derivatives(ueKPar_K, dUEKPar);
5454
for (int m = 0; m < g.M; ++m) {
55-
derivatives(momentK(m), Grid::sliceXY(dGM, m));
55+
br.derivatives(momentK(m), Grid::sliceXY(dGM, m));
5656
}
57-
5857
}
5958

6059
void Naive::run(Dim N, Dim saveInterval) {
@@ -92,8 +91,8 @@ void Naive::run(Dim N, Dim saveInterval) {
9291
auto GM_K_Star = g.cBufMXY(), GM_Nonlinear_K = g.cBufMXY();
9392

9493
// Compute N
95-
auto bracketPhiNE_K = halfBracket(dPhi, Grid::sliceXY(dGM, N_E));
96-
auto bracketAParUEKPar_K = halfBracket(Grid::sliceXY(dGM, A_PAR), dUEKPar);
94+
auto bracketPhiNE_K = br.halfBracket(dPhi, Grid::sliceXY(dGM, N_E));
95+
auto bracketAParUEKPar_K = br.halfBracket(Grid::sliceXY(dGM, A_PAR), dUEKPar);
9796

9897
// Compute A
9998
auto dPhiNeG2 = g.dBufXY();
@@ -111,8 +110,8 @@ void Naive::run(Dim N, Dim saveInterval) {
111110
});
112111
}
113112

114-
auto bracketAParPhiG2Ne_K = halfBracket(Grid::sliceXY(dGM, A_PAR), dPhiNeG2);
115-
auto bracketUEParPhi_K = halfBracket(dUEKPar, dPhi);
113+
auto bracketAParPhiG2Ne_K = br.halfBracket(Grid::sliceXY(dGM, A_PAR), dPhiNeG2);
114+
auto bracketUEParPhi_K = br.halfBracket(dUEKPar, dPhi);
116115

117116
g.for_each_kxky([&](Dim kx, Dim ky) {
118117
GM_Nonlinear_K(kx, ky, N_E) =
@@ -131,21 +130,22 @@ void Naive::run(Dim N, Dim saveInterval) {
131130

132131
if (g.M > 2) {
133132
// Compute G2
134-
auto bracketPhiG2_K = halfBracket(dPhi, Grid::sliceXY(dGM, G_MIN));
135-
auto bracketAParG3_K = halfBracket(Grid::sliceXY(dGM, A_PAR), Grid::sliceXY(dGM, G_MIN + 1));
133+
auto bracketPhiG2_K = br.halfBracket(dPhi, Grid::sliceXY(dGM, G_MIN));
134+
auto bracketAParG3_K =
135+
br.halfBracket(Grid::sliceXY(dGM, A_PAR), Grid::sliceXY(dGM, G_MIN + 1));
136136

137137
// Compute G_{M-1}
138-
auto bracketPhiGLast_K = halfBracket(dPhi, Grid::sliceXY(dGM, LAST));
139-
auto bracketAParGLast_K = halfBracket(Grid::sliceXY(dGM, A_PAR), Grid::sliceXY(dGM, LAST));
138+
auto bracketPhiGLast_K = br.halfBracket(dPhi, Grid::sliceXY(dGM, LAST));
139+
auto bracketAParGLast_K = br.halfBracket(Grid::sliceXY(dGM, A_PAR), Grid::sliceXY(dGM, LAST));
140140
g.for_each_kxky([&](Dim kx, Dim ky) {
141141
bracketAParGLast_K(kx, ky) *= nonlinear::GLastBracketFactor(g.M, kPerp2(kx, ky), hyper);
142142
bracketAParGLast_K(kx, ky) += rhoS / de * std::sqrt(LAST) * moments_K(kx, ky, LAST - 1);
143143
// TODO Viriato adds this after the derivative
144144
});
145145

146146
auto dBrLast = g.dBufXY();
147-
derivatives(bracketAParGLast_K, dBrLast);
148-
auto bracketTotalGLast_K = halfBracket(Grid::sliceXY(dGM, A_PAR), dBrLast);
147+
br.derivatives(bracketAParGLast_K, dBrLast);
148+
auto bracketTotalGLast_K = br.halfBracket(Grid::sliceXY(dGM, A_PAR), dBrLast);
149149

150150
g.for_each_kxky([&](Dim kx, Dim ky) {
151151
GM_Nonlinear_K(kx, ky, G_MIN) = nonlinear::G2(
@@ -172,8 +172,8 @@ void Naive::run(Dim N, Dim saveInterval) {
172172
std::sqrt(m) * dGM.DY(x, y, m - 1) + std::sqrt(m + 1) * dGM.DY(x, y, m + 1);
173173
});
174174

175-
auto bracketAParGMMinusPlus_K = halfBracket(Grid::sliceXY(dGM, A_PAR), dGMinusPlus);
176-
auto bracketPhiGM_K = halfBracket(dPhi, Grid::sliceXY(dGM, m));
175+
auto bracketAParGMMinusPlus_K = br.halfBracket(Grid::sliceXY(dGM, A_PAR), dGMinusPlus);
176+
auto bracketPhiGM_K = br.halfBracket(dPhi, Grid::sliceXY(dGM, m));
177177

178178
g.for_each_kxky([&](Dim kx, Dim ky) {
179179
GM_Nonlinear_K(kx, ky, m) =
@@ -198,12 +198,12 @@ void Naive::run(Dim N, Dim saveInterval) {
198198

199199
auto dPhi_Loop = g.dBufXY(), dUEKPar_Loop = g.dBufXY();
200200
auto dGM_Loop = g.dBufMXY();
201-
derivatives(phi_K_New, dPhi_Loop);
202-
derivatives(ueKPar_K_New, dUEKPar_Loop);
201+
br.derivatives(phi_K_New, dPhi_Loop);
202+
br.derivatives(ueKPar_K_New, dUEKPar_Loop);
203203

204204
for (int m = 0; m < g.M; ++m) {
205205
// TODO(OPT) not necessary if we bail (only up to G_MIN)
206-
derivatives(Grid::sliceXY(GM_K_Star, m), Grid::sliceXY(dGM_Loop, m));
206+
br.derivatives(Grid::sliceXY(GM_K_Star, m), Grid::sliceXY(dGM_Loop, m));
207207
}
208208

209209
// Corrector loop
@@ -221,7 +221,7 @@ void Naive::run(Dim N, Dim saveInterval) {
221221

222222
for (int p = 0; p <= MaxP; ++p) {
223223
auto DerivateNewMoment = [&](Dim m) {
224-
derivatives(Grid::sliceXY(momentsNew_K, m), Grid::sliceXY(dGM_Loop, m));
224+
br.derivatives(Grid::sliceXY(momentsNew_K, m), Grid::sliceXY(dGM_Loop, m));
225225
};
226226

227227
// First, compute A_par
@@ -240,8 +240,9 @@ void Naive::run(Dim N, Dim saveInterval) {
240240
}
241241
});
242242

243-
auto bracketAParPhiG2Ne_K_Loop = halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dPhiNeG2_Loop);
244-
auto bracketUEParPhi_K_Loop = halfBracket(dUEKPar_Loop, dPhi_Loop);
243+
auto bracketAParPhiG2Ne_K_Loop =
244+
br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dPhiNeG2_Loop);
245+
auto bracketUEParPhi_K_Loop = br.halfBracket(dUEKPar_Loop, dPhi_Loop);
245246

246247
/// f_pred from Viriato
247248
auto GM_Nonlinear_K_Loop = g.cBufMXY();
@@ -275,10 +276,10 @@ void Naive::run(Dim N, Dim saveInterval) {
275276
// TODO(OPT) bail if relative error is large
276277

277278
DerivateNewMoment(A_PAR);
278-
derivatives(ueKPar_K_New, dUEKPar_Loop);
279+
br.derivatives(ueKPar_K_New, dUEKPar_Loop);
279280

280-
auto bracketPhiNE_K_Loop = halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, N_E));
281-
auto bracketAParUEKPar_K_Loop = halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dUEKPar_Loop);
281+
auto bracketPhiNE_K_Loop = br.halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, N_E));
282+
auto bracketAParUEKPar_K_Loop = br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dUEKPar_Loop);
282283

283284
g.for_each_kxky([&](Dim kx, Dim ky) {
284285
GM_Nonlinear_K_Loop(kx, ky, N_E) =
@@ -293,13 +294,13 @@ void Naive::run(Dim N, Dim saveInterval) {
293294
(kx | ky) == 0 ? 0 : nonlinear::phi(momentsNew_K(kx, ky, N_E), kPerp2(kx, ky));
294295
});
295296

296-
derivatives(phi_K_New, dPhi_Loop);
297+
br.derivatives(phi_K_New, dPhi_Loop);
297298
DerivateNewMoment(N_E);
298299
if (g.M > 2) {
299300
// Compute G2
300-
auto bracketPhiG2_K_Loop = halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, G_MIN));
301+
auto bracketPhiG2_K_Loop = br.halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, G_MIN));
301302
auto bracketAParG3_K_Loop =
302-
halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), Grid::sliceXY(dGM_Loop, G_MIN + 1));
303+
br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), Grid::sliceXY(dGM_Loop, G_MIN + 1));
303304

304305
g.for_each_kxky([&](Dim kx, Dim ky) {
305306
GM_Nonlinear_K_Loop(kx, ky, G_MIN) =
@@ -323,8 +324,8 @@ void Naive::run(Dim N, Dim saveInterval) {
323324
});
324325

325326
auto bracketAParGMMinusPlus_K_Loop =
326-
halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dGMinusPlus_Loop);
327-
auto bracketPhiGM_K_Loop = halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, m));
327+
br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dGMinusPlus_Loop);
328+
auto bracketPhiGM_K_Loop = br.halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, m));
328329

329330
g.for_each_kxky([&](Dim kx, Dim ky) {
330331
GM_Nonlinear_K_Loop(kx, ky, m) = nonlinear::GM(m, bracketPhiGM_K_Loop(kx, ky),
@@ -341,9 +342,9 @@ void Naive::run(Dim N, Dim saveInterval) {
341342
}
342343

343344
// Compute G_{M-1}
344-
auto bracketPhiGLast_K_Loop = halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, LAST));
345+
auto bracketPhiGLast_K_Loop = br.halfBracket(dPhi_Loop, Grid::sliceXY(dGM_Loop, LAST));
345346
auto bracketAParGLast_K_Loop =
346-
halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), Grid::sliceXY(dGM_Loop, LAST));
347+
br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), Grid::sliceXY(dGM_Loop, LAST));
347348
g.for_each_kxky([&](Dim kx, Dim ky) {
348349
bracketAParGLast_K_Loop(kx, ky) *=
349350
nonlinear::GLastBracketFactor(g.M, kPerp2(kx, ky), hyper);
@@ -353,8 +354,9 @@ void Naive::run(Dim N, Dim saveInterval) {
353354
});
354355

355356
DxDy<Buf::R_XY> dBrLast_Loop = g.dBufXY();
356-
derivatives(bracketAParGLast_K_Loop, dBrLast_Loop);
357-
auto bracketTotalGLast_K_Loop = halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dBrLast_Loop);
357+
br.derivatives(bracketAParGLast_K_Loop, dBrLast_Loop);
358+
auto bracketTotalGLast_K_Loop =
359+
br.halfBracket(Grid::sliceXY(dGM_Loop, A_PAR), dBrLast_Loop);
358360

359361
g.for_each_kxky([&](Dim kx, Dim ky) {
360362
GM_Nonlinear_K_Loop(kx, ky, LAST) =
@@ -404,7 +406,8 @@ void Naive::run(Dim N, Dim saveInterval) {
404406
this->elapsedT += dt;
405407

406408
// Update dt
407-
Real tempDt = getTimestep(dPhi_Loop, Grid::sliceXY(dGM_Loop, N_E), Grid::sliceXY(dGM_Loop, A_PAR));
409+
Real tempDt =
410+
getTimestep(dPhi_Loop, Grid::sliceXY(dGM_Loop, N_E), Grid::sliceXY(dGM_Loop, A_PAR));
408411
dt = updateTimestep(dt, tempDt, noInc, relative_error);
409412
hyper = HyperCoefficients::calculate(dt, g);
410413

@@ -431,8 +434,9 @@ void Naive::run(Dim N, Dim saveInterval) {
431434
// Log moment values when level is trace (most verbose)
432435

433436
for (Dim m = 0; m < g.M; ++m) {
434-
spdlog::trace("t={} m={}:\n{}", t, m,
435-
fmt::streamed(ostream_tuple(std::setprecision(16), Grid::sliceXY(moments_K, m))));
437+
spdlog::trace(
438+
"t={} m={}:\n{}", t, m,
439+
fmt::streamed(ostream_tuple(std::setprecision(16), Grid::sliceXY(moments_K, m))));
436440
}
437441
}
438442

@@ -475,31 +479,6 @@ Naive::Buf::R_XY Naive::getMoment(Dim m) const {
475479
return out;
476480
}
477481

478-
[[nodiscard]] Naive::Buf::C_XY Naive::fullBracket(View::C_XY op1, View::C_XY op2) {
479-
auto derOp1 = g.dBufXY(), derOp2 = g.dBufXY();
480-
derivatives(op1, derOp1);
481-
derivatives(op2, derOp2);
482-
483-
return halfBracket(derOp1, derOp2);
484-
}
485-
486-
void Naive::derivatives(const View::C_XY &op, Naive::DxDy<View::R_XY> output) {
487-
DxDy<Buf::C_XY> Der_K{g.KX, g.KY};
488-
prepareDXY_PH(op, Der_K.DX, Der_K.DY);
489-
tf.bfft(Der_K.DX.to_mdspan(), output.DX);
490-
tf.bfft(Der_K.DY.to_mdspan(), output.DY);
491-
}
492-
493-
Naive::Buf::C_XY Naive::halfBracket(Naive::DxDy<View::R_XY> derOp1,
494-
Naive::DxDy<View::R_XY> derOp2) {
495-
Buf::R_XY br = g.rBufXY();
496-
Buf::C_XY br_K = g.cBufXY();
497-
bracket(derOp1, derOp2, br);
498-
fftHL(br.to_mdspan(), br_K.to_mdspan());
499-
br_K(0, 0) = 0;
500-
return br_K;
501-
}
502-
503482
Naive::Energies Naive::calculateEnergies() const {
504483
Energies e{};
505484
g.for_each_kxky([&](Dim kx, Dim ky) {

0 commit comments

Comments
 (0)