Skip to content

Commit 0c8292c

Browse files
committed
Extract PrepareDerivatives
1 parent 0227db1 commit 0c8292c

File tree

5 files changed

+46
-25
lines changed

5 files changed

+46
-25
lines changed

cpp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib
6262
lib/Transformer.hpp lib/Transformer.cpp
6363
lib/Exporter.hpp lib/Exporter.cpp
6464
lib/Filter.hpp lib/Filter.cpp
65+
lib/PrepareDerivatives.hpp lib/PrepareDerivatives.cpp
6566
lib/Brackets.hpp lib/Brackets.cpp
6667
)
6768
target_include_directories(src-lib PUBLIC lib/)

cpp/lib/Brackets.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
#include "Brackets.hpp"
2-
#include "Transformer.hpp"
32
#include "Filter.hpp"
3+
#include "Transformer.hpp"
44

55
namespace ahr {
66

7-
void Brackets::prepareDXY_PH(View::C_XY const &view_K, View::C_XY const &viewDX_K,
8-
View::C_XY const &viewDY_K) const {
9-
grid.for_each_kxky([&](Dim kx, Dim ky) {
10-
viewDX_K(kx, ky) = kx_(kx) * 1i * view_K(kx, ky) * XYNorm;
11-
viewDY_K(kx, ky) = ky_(ky) * 1i * view_K(kx, ky) * XYNorm;
12-
});
13-
}
14-
157
void Brackets::bracket(DxDy<View::R_XY> const &op1, DxDy<View::R_XY> const &op2,
168
View::R_XY const &output) const {
179
grid.for_each_xy([&](Dim x, Dim y) {
@@ -21,7 +13,7 @@ void Brackets::bracket(DxDy<View::R_XY> const &op1, DxDy<View::R_XY> const &op2,
2113

2214
void Brackets::derivatives(View::C_XY const &op, DxDy<View::R_XY> output) const {
2315
DxDy<Buf::C_XY> Der_K{grid.KX, grid.KY};
24-
prepareDXY_PH(op, Der_K.DX, Der_K.DY);
16+
prepareDXY(op, Der_K);
2517
tf.bfft(Der_K.DX, output.DX);
2618
tf.bfft(Der_K.DY, output.DY);
2719
}

cpp/lib/Brackets.hpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "constants.hpp"
2+
#include "PrepareDerivatives.hpp"
33
#include "grid.hpp"
44

55
namespace ahr {
@@ -16,6 +16,7 @@ class Brackets {
1616
Grid const &grid;
1717
Transformer const &tf;
1818
HouLiFilter const &hlFilter; // TODO vectorized
19+
PrepareDerivatives prepareDXY{grid};
1920

2021
using View = Grid::View;
2122
using Buf = Grid::Buf;
@@ -32,22 +33,8 @@ class Brackets {
3233
[[nodiscard]] Buf::C_XY fullBracket(View::C_XY op1, View::C_XY op2) const;
3334

3435
private:
35-
/// Prepares the δx and δy of viewPH in phase space, as well as over-normalizes
36-
/// (after inverse FFT, values will be properly normalized)
37-
void prepareDXY_PH(View::C_XY const &view_K, View::C_XY const &viewDX_K,
38-
View::C_XY const &viewDY_K) const;
39-
4036
/// Computes bracket [op1, op2], expects normalized values
4137
void bracket(DxDy<View::R_XY> const &op1, DxDy<View::R_XY> const &op2,
4238
View::R_XY const &output) const;
43-
44-
/// Normalization factor for FFT
45-
const Real XYNorm{1.0 / Real(grid.X) / Real(grid.Y)};
46-
47-
// TODO extract to common utility
48-
[[nodiscard]] Real ky_(Dim ky) const {
49-
return (ky <= (grid.KY / 2) ? Real(ky) : Real(ky) - Real(grid.KY)) * Real(lx) / Real(ly);
50-
}
51-
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
5239
};
5340
} // namespace ahr

cpp/lib/PrepareDerivatives.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include "PrepareDerivatives.hpp"
2+
3+
namespace ahr {
4+
5+
void PrepareDerivatives::operator()(View::C_XY const &in, DxDy<View::C_XY> out) const {
6+
grid.for_each_kxky([&](Dim kx, Dim ky) {
7+
out.DX(kx, ky) = kx_(kx) * 1i * in(kx, ky) * XYNorm;
8+
out.DY(kx, ky) = ky_(ky) * 1i * in(kx, ky) * XYNorm;
9+
});
10+
}
11+
12+
} // namespace ahr

cpp/lib/PrepareDerivatives.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include "constants.hpp"
4+
#include "grid.hpp"
5+
6+
namespace ahr {
7+
8+
class PrepareDerivatives {
9+
using View = Grid::View;
10+
template <class T> using DxDy = Grid::DxDy<T>;
11+
12+
public:
13+
explicit PrepareDerivatives(Grid const &grid) : grid(grid) {}
14+
15+
void operator()(View::C_XY const &in, DxDy<View::C_XY> out) const;
16+
17+
private:
18+
Grid const &grid;
19+
20+
/// Normalization factor for FFT
21+
const Real XYNorm{1.0 / Real(grid.X) / Real(grid.Y)};
22+
23+
// TODO extract to common utility
24+
[[nodiscard]] Real ky_(Dim ky) const {
25+
return (ky <= (grid.KY / 2) ? Real(ky) : Real(ky) - Real(grid.KY)) * Real(lx) / Real(ly);
26+
}
27+
[[nodiscard]] Real kx_(Dim kx) const { return Real(kx); }
28+
};
29+
} // namespace ahr

0 commit comments

Comments
 (0)