Skip to content

Commit 3834192

Browse files
committed
Add transformer class
1 parent beac1e5 commit 3834192

File tree

8 files changed

+156
-24
lines changed

8 files changed

+156
-24
lines changed

cpp/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ add_executable(adaptive_hermite_refinement lib/test-triangle.cpp)
5858
target_include_directories(adaptive_hermite_refinement PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
5959
target_link_libraries(adaptive_hermite_refinement argparse mdspan)
6060

61-
add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib/HermiteRunner.hpp)
61+
add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib/HermiteRunner.hpp
62+
lib/Transformer.cpp
63+
lib/Transformer.hpp)
6264
target_include_directories(src-lib PUBLIC lib/)
6365
target_link_libraries(src-lib mdspan fftw-cpp cnpy spdlog::spdlog)
6466

cpp/lib/Naive.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ void Naive::hlFilter(View::C_XY &complexArray) {
2626
});
2727
}
2828

29-
void Naive::fft(View::R_XY in, View::C_XY out) {
30-
fft_base(in, out);
29+
void Naive::fftHL(View::R_XY in, View::C_XY out) {
30+
tf.fft(in, out);
3131
hlFilter(out);
3232
}
3333

@@ -37,14 +37,13 @@ void Naive::init(std::string_view equilibriumName) {
3737
auto temp = g.rBufXY();
3838

3939
// Plan FFTs both ways
40-
fft_base = fftw::plan_r2c<2u>::dft(temp.to_mdspan(), phi_K.to_mdspan(), fftw::ESTIMATE);
41-
fftInv = fftw::plan_c2r<2u>::dft(phi_K.to_mdspan(), temp.to_mdspan(), fftw::ESTIMATE);
40+
tf.init();
4241

4342
// Initialize equilibrium values
4443
auto [aParEq, phi] = equilibrium(equilibriumName, g);
4544

46-
fft(phi.to_mdspan(), phi_K.to_mdspan());
47-
fft(aParEq.to_mdspan(), aParEq_K.to_mdspan());
45+
fftHL(phi.to_mdspan(), phi_K.to_mdspan());
46+
fftHL(aParEq.to_mdspan(), aParEq_K.to_mdspan());
4847

4948
// Transform moments into phase space
5049
for (int m = G_MIN; m < g.M; ++m) {
@@ -475,7 +474,7 @@ Real Naive::updateTimestep(Real dt, Real tempDt, bool noInc, Real relative_error
475474
mdarray<Real, dextents<Dim, 2u>> Naive::getFinalAPar() {
476475
Buf::R_XY buf = g.rBufXY();
477476
// This actually wrecks A_PAR, but we don't need it anymore
478-
fftInv(Grid::sliceXY(moments_K, A_PAR), buf.to_mdspan());
477+
tf.bfft(Grid::sliceXY(moments_K, A_PAR), buf.to_mdspan());
479478

480479
// Write to a layout_right array and normalize
481480
mdarray<Real, dextents<Dim, 2u>> result{g.X, g.Y};
@@ -490,7 +489,7 @@ Naive::Buf::R_XY Naive::getMoment(Dim m) const {
490489
g.for_each_kxky([&](Dim kx, Dim ky) { tmp(kx, ky) = moments_K(kx, ky, m); });
491490

492491
Buf::R_XY out = g.rBufXY();
493-
fftInv(tmp.to_mdspan(), out.to_mdspan());
492+
tf.bfft(tmp.to_mdspan(), out.to_mdspan());
494493

495494
return out;
496495
}
@@ -506,16 +505,16 @@ Naive::Buf::R_XY Naive::getMoment(Dim m) const {
506505
void Naive::derivatives(const View::C_XY &op, Naive::DxDy<View::R_XY> output) {
507506
DxDy<Buf::C_XY> Der_K{g.KX, g.KY};
508507
prepareDXY_PH(op, Der_K.DX, Der_K.DY);
509-
fftInv(Der_K.DX.to_mdspan(), output.DX);
510-
fftInv(Der_K.DY.to_mdspan(), output.DY);
508+
tf.bfft(Der_K.DX.to_mdspan(), output.DX);
509+
tf.bfft(Der_K.DY.to_mdspan(), output.DY);
511510
}
512511

513512
Naive::Buf::C_XY Naive::halfBracket(Naive::DxDy<View::R_XY> derOp1,
514513
Naive::DxDy<View::R_XY> derOp2) {
515514
Buf::R_XY br = g.rBufXY();
516515
Buf::C_XY br_K = g.cBufXY();
517516
bracket(derOp1, derOp2, br);
518-
fft(br.to_mdspan(), br_K.to_mdspan());
517+
fftHL(br.to_mdspan(), br_K.to_mdspan());
519518
br_K(0, 0) = 0;
520519
return br_K;
521520
}
@@ -532,16 +531,12 @@ void Naive::exportToNpy(std::string path, View::C_XY view) const {
532531

533532
g.for_each_kxky([&](Dim kx, Dim ky) { tempK(kx, ky) = view(kx, ky); });
534533

535-
fftInv(tempK.to_mdspan(), temp.to_mdspan());
536-
normalize(temp.to_mdspan(), temp.to_mdspan());
534+
tf.bfft(tempK.to_mdspan(), temp.to_mdspan());
535+
tf.normalize(temp.to_mdspan(), temp.to_mdspan());
537536

538537
exportToNpy(std::move(path), temp.to_mdspan());
539538
}
540539

541-
void Naive::normalize(View::R_XY view, View::R_XY viewOut) const {
542-
g.for_each_xy([&](Dim x, Dim y) { viewOut(x, y) = view(x, y) * XYNorm; });
543-
}
544-
545540
Naive::Energies Naive::calculateEnergies() const {
546541
Energies e{};
547542
g.for_each_kxky([&](Dim kx, Dim ky) {

cpp/lib/Naive.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "HermiteRunner.hpp"
4+
#include "Transformer.hpp"
45
#include "constants.hpp"
56
#include "debug.hpp"
67
#include "grid.hpp"
@@ -35,6 +36,7 @@ class Naive : public ahr::HermiteRunner {
3536

3637
private:
3738
Grid g;
39+
Transformer tf{g};
3840

3941
using View = Grid::View;
4042
using Buf = Grid::Buf;
@@ -45,10 +47,8 @@ class Naive : public ahr::HermiteRunner {
4547
Real elapsedT{0.0}; ///< total time elapsed
4648

4749
void hlFilter(View::C_XY &complexArray);
48-
void fft(View::R_XY in, View::C_XY out); ///< FFT with Hou-Li Filter
50+
void fftHL(View::R_XY in, View::C_XY out); ///< FFT with Hou-Li Filter
4951

50-
fftw::plan_r2c<2u> fft_base{};
51-
fftw::plan_c2r<2u> fftInv{};
5252
Real bPerpMax{0};
5353

5454
static constexpr Dim N_E = 0;
@@ -232,9 +232,6 @@ class Naive : public ahr::HermiteRunner {
232232
void exportToNpy(std::string path, View::C_XY view) const;
233233

234234
private:
235-
// If view = viewOut, then we're normalizing in place.
236-
void normalize(Naive::View::R_XY view, Naive::View::R_XY viewOut) const;
237-
238235
void exportTimestep(Dim t);
239236
};
240237
}; // namespace ahr

cpp/lib/Transformer.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//
2+
// Created by Luka on 10/23/2024.
3+
//
4+
5+
#include "Transformer.hpp"
6+
7+
namespace ahr {
8+
void ahr::Transformer::init() {
9+
auto r = grid.rBufXY();
10+
auto c = grid.cBufXY();
11+
fftFwd = fftw::plan_r2c<2u>::dft(r.to_mdspan(), c.to_mdspan(), fftw::ESTIMATE);
12+
fftBwd = fftw::plan_c2r<2u>::dft(c.to_mdspan(), r.to_mdspan(), fftw::ESTIMATE);
13+
}
14+
void Transformer::fft(Grid::View::R_XY in, Grid::View::C_XY out) const { fftFwd(in, out); }
15+
16+
void Transformer::bfft(Grid::View::C_XY in, Grid::View::R_XY out) const { fftBwd(in, out); }
17+
18+
void Transformer::normalize(Grid::View::C_XY view, Grid::View::C_XY out) const {
19+
grid.for_each_kxky([&](Dim kx, Dim ky) { out(kx, ky) = view(kx, ky) * XYNorm; });
20+
}
21+
22+
void Transformer::normalize(Grid::View::R_XY view, Grid::View::R_XY out) const {
23+
grid.for_each_xy([&](Dim x, Dim y) { out(x, y) = view(x, y) * XYNorm; });
24+
}
25+
26+
} // namespace ahr

cpp/lib/Transformer.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include "grid.hpp"
4+
5+
namespace ahr {
6+
7+
class Transformer {
8+
9+
public:
10+
explicit Transformer(Grid const &grid) : grid(grid) {}
11+
12+
/// Plan FFTs, etc.
13+
void init();
14+
15+
/// Forward FFT
16+
void fft(Grid::View::R_XY in, Grid::View::C_XY out) const;
17+
18+
/// Backwards FFT (unnormalized)
19+
void bfft(Grid::View::C_XY in, Grid::View::R_XY out) const;
20+
21+
/// Normalize a complex buffer (can be in-place)
22+
void normalize(Grid::View::C_XY view, Grid::View::C_XY out) const;
23+
24+
/// Normalize a real buffer (can be in-place)
25+
void normalize(Grid::View::R_XY view, Grid::View::R_XY out) const;
26+
27+
private:
28+
fftw::plan_r2c<2u> fftFwd{};
29+
fftw::plan_c2r<2u> fftBwd{};
30+
31+
Grid const &grid;
32+
33+
/// Normalization factor for FFT
34+
Real XYNorm{1.0 / double(grid.X) / double(grid.Y)};
35+
};
36+
37+
} // namespace ahr

cpp/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ function(add_viriato_test FILE)
1919
endfunction()
2020

2121
add_viriato_test(grid.cpp)
22+
add_viriato_test(transformer.cpp)
2223

2324
add_viriato_test(naive-energies.cpp)
2425
add_viriato_test(naive-moments.cpp)

cpp/test/transformer.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include "Transformer.hpp"
2+
3+
#include "util.hpp"
4+
#include <gtest/gtest.h>
5+
6+
namespace ahr {
7+
8+
TEST(Transformer, Forward) {
9+
Grid grid{5, 16, 24}; // Example grid dimensions
10+
Transformer tf{grid};
11+
tf.init();
12+
13+
auto r = grid.rBufXY();
14+
auto c = grid.cBufXY();
15+
16+
// Initialize the input grid with a delta function
17+
grid.for_each_xy([&](Dim x, Dim y) { r(x, y) = (x == 0 && y == 0) ? 1.0 : 0.0; });
18+
19+
// Perform FFT
20+
tf.fft(r, c);
21+
22+
// Check the output (constant)
23+
grid.for_each_kxky([&](Dim kx, Dim ky) {
24+
EXPECT_NEAR(c(kx, ky).real(), 1.0, 1e-6);
25+
EXPECT_NEAR(c(kx, ky).imag(), 0.0, 1e-6);
26+
});
27+
}
28+
29+
TEST(Transformer, Backward) {
30+
Grid grid{5, 16, 24};
31+
Transformer tf{grid};
32+
tf.init();
33+
34+
auto r = grid.rBufXY();
35+
auto c = grid.cBufXY();
36+
37+
// Initialize the input with a constant
38+
grid.for_each_kxky([&](Dim kx, Dim ky) { c(kx, ky) = 1.0; });
39+
40+
tf.bfft(c, r);
41+
42+
// Check the output (delta function)
43+
grid.for_each_xy([&](Dim x, Dim y) {
44+
if (x == 0 && y == 0) {
45+
EXPECT_NEAR(r(x, y), grid.X * grid.Y, 1e-6);
46+
} else {
47+
EXPECT_NEAR(r(x, y), 0.0, 1e-6);
48+
}
49+
});
50+
}
51+
52+
TEST(Transformer, RoundTrip) {
53+
Grid grid{5, 16, 24};
54+
Transformer tf{grid};
55+
tf.init();
56+
57+
auto r = grid.rBufXY(), r2 = grid.rBufXY();
58+
auto c = grid.cBufXY(), c2 = grid.cBufXY();
59+
60+
grid.for_each_xy([&](Dim x, Dim y) {
61+
using namespace std;
62+
using namespace std::numbers;
63+
r(x, y) = cos(4 * pi * x / grid.X) + 2 * cos(2 * pi * y / grid.Y);
64+
});
65+
66+
tf.fft(r, c);
67+
tf.normalize(c, c2);
68+
tf.bfft(c2, r2);
69+
70+
EXPECT_THAT(r2.to_mdspan(), MdspanElementsAllClose(r.to_mdspan(), 1e-10, 1e-10));
71+
}
72+
73+
} // namespace ahr

cpp/test/util.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <fftw-cpp/fftw-cpp.h>
44
#include <gmock/gmock.h>
55
#include <tuple>
6+
#include <iomanip>
67

78
// Function to slice the tuple
89
template <std::size_t Start, std::size_t End, typename Tuple, std::size_t... Indices>

0 commit comments

Comments
 (0)