Skip to content

Commit bb806fc

Browse files
committed
Extract exporting (and importing) to Exporter. Currently complex is failing
1 parent 3834192 commit bb806fc

File tree

9 files changed

+217
-57
lines changed

9 files changed

+217
-57
lines changed

cpp/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ target_include_directories(adaptive_hermite_refinement PUBLIC ${CMAKE_CURRENT_SO
5959
target_link_libraries(adaptive_hermite_refinement argparse mdspan)
6060

6161
add_library(src-lib OBJECT lib/Naive.cpp lib/Naive.hpp lib/HermiteRunner.cpp lib/HermiteRunner.hpp
62-
lib/Transformer.cpp
63-
lib/Transformer.hpp)
62+
lib/Transformer.hpp lib/Transformer.cpp
63+
lib/Exporter.hpp lib/Exporter.cpp
64+
)
6465
target_include_directories(src-lib PUBLIC lib/)
6566
target_link_libraries(src-lib mdspan fftw-cpp cnpy spdlog::spdlog)
6667

cpp/lib/Exporter.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#include "Exporter.hpp"
2+
#include "Transformer.hpp"
3+
#include <cnpy.h>
4+
#include <filesystem>
5+
6+
namespace ahr {
7+
namespace fs = std::filesystem;
8+
void Exporter::exportTo(fs::path const &filename, Grid::View::C_XY cView) {
9+
// fft overwrites the input, so we need to copy it to a temporary buffer
10+
auto tempK = grid.cBufXY();
11+
auto temp = grid.rBufXY();
12+
13+
// Copy and also normalize
14+
tf.normalize(cView, tempK);
15+
16+
// Backwards FFT
17+
tf.bfft(tempK, temp);
18+
19+
// Write the real buffer to file
20+
exportTo(filename, temp);
21+
}
22+
23+
void Exporter::exportTo(fs::path const &filename, Grid::View::R_XY rView) {
24+
// TODO Copy the data to a layout-right buffer
25+
// stdex::mdarray<Real, stdex::dextents<size_t, 2u>> rArray{rView.extents()};
26+
// grid.for_each_xy([&](Dim x, Dim y) { rArray(x, y) = rView(x, y); });
27+
// cnpy::npy_save(prefix_dir / filename, rArray.data(), {grid.X, grid.Y}, "w");
28+
29+
// Dimensions are switched because we use layout_left
30+
auto const path = filename.is_absolute() ? filename : prefix_dir / filename;
31+
cnpy::npy_save(path, rView.data_handle(), {grid.Y, grid.X}, "w");
32+
}
33+
34+
NpyMdspan
35+
Exporter::importReal(const fs::path &filename) { // NOLINT(*-convert-member-functions-to-static)
36+
auto const path = filename.is_absolute() ? filename : prefix_dir / filename;
37+
return NpyMdspan{cnpy::npy_load(path)};
38+
}
39+
40+
void Exporter::importReal(const fs::path &filename, Grid::View::R_XY rView) {
41+
auto npy = importReal(filename);
42+
if (!npy.valid()) { throw std::runtime_error("Invalid npy file"); }
43+
if (npy.view().extents() != rView.extents()) { throw std::runtime_error("Incompatible extents"); }
44+
45+
// Copy the data
46+
grid.for_each_xy([&](Dim x, Dim y) { rView(x, y) = npy.view()(x, y); });
47+
}
48+
49+
Grid::Buf::R_XY Exporter::importRealBuf(const fs::path &filename) {
50+
auto rBuf = grid.rBufXY();
51+
importReal(filename, rBuf.to_mdspan());
52+
return rBuf;
53+
}
54+
55+
} // namespace ahr

cpp/lib/Exporter.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#pragma once
2+
3+
#include "NpyMdspan.hpp"
4+
#include "grid.hpp"
5+
#include <filesystem>
6+
#include <optional>
7+
8+
namespace ahr {
9+
namespace fs = std::filesystem;
10+
class Transformer;
11+
12+
/// This class is responsible for exporting and importing .npy buffers.
13+
/// It can automatically transform complex buffers to real before exporting.
14+
/// Any relative path is interpreted as relative to the prefix_dir.
15+
class Exporter {
16+
Grid const &grid;
17+
Transformer const &tf;
18+
fs::path prefix_dir;
19+
20+
public:
21+
Exporter(Grid const &grid, Transformer const &transformer,
22+
std::optional<fs::path> prefix_dir = std::nullopt)
23+
: grid(grid), tf(transformer) {
24+
if (prefix_dir) {
25+
this->prefix_dir = fs::canonical(*prefix_dir);
26+
} else if (auto prefix_dir_env = std::getenv("EXPORT_PREFIX_DIR"); prefix_dir_env) {
27+
this->prefix_dir = fs::canonical(prefix_dir_env);
28+
} else {
29+
this->prefix_dir = fs::current_path();
30+
}
31+
32+
fs::create_directories(this->prefix_dir);
33+
}
34+
35+
/// Export the complex view to file by transforming it to real first.
36+
void exportTo(fs::path const &filename, Grid::View::C_XY cView);
37+
38+
/// Export the real buffer to file.
39+
void exportTo(fs::path const &filename, Grid::View::R_XY rView);
40+
41+
/// Import a real buffer from file, and write it to the given view.
42+
void importReal(const fs::path &filename, Grid::View::R_XY rView);
43+
44+
/// Import a real buffer from file, and return it as an owning NpyMdspan.
45+
[[nodiscard]] NpyMdspan importReal(fs::path const &filename);
46+
47+
/// Import a real buffer from file, and return it as a real buffer.
48+
[[nodiscard]] Grid::Buf::R_XY importRealBuf(fs::path const &filename);
49+
50+
};
51+
} // namespace ahr

cpp/lib/Naive.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -452,15 +452,15 @@ void Naive::run(Dim N, Dim saveInterval) {
452452
void Naive::exportTimestep(Dim t) {
453453
std::ostringstream oss;
454454
oss << "a_par_t" << t << ".npy";
455-
exportToNpy(oss.str(), Grid::sliceXY(moments_K, A_PAR));
455+
exporter.exportTo(oss.str(), Grid::sliceXY(moments_K, A_PAR));
456456

457457
oss.str("");
458458
oss << "phi_t" << t << ".npy";
459-
exportToNpy(oss.str(), phi_K);
459+
exporter.exportTo(oss.str(), phi_K);
460460

461461
oss.str("");
462462
oss << "uekpar_t" << t << ".npy";
463-
exportToNpy(oss.str(), ueKPar_K);
463+
exporter.exportTo(oss.str(), ueKPar_K);
464464
}
465465

466466
Real Naive::updateTimestep(Real dt, Real tempDt, bool noInc, Real relative_error) const {
@@ -519,24 +519,6 @@ Naive::Buf::C_XY Naive::halfBracket(Naive::DxDy<View::R_XY> derOp1,
519519
return br_K;
520520
}
521521

522-
void Naive::exportToNpy(std::string path, View::R_XY view) const {
523-
// Coordinates are flipped because we use layout_left
524-
cnpy::npy_save(std::move(path), view.data_handle(), {g.Y, g.X}, "w");
525-
}
526-
527-
void Naive::exportToNpy(std::string path, View::C_XY view) const {
528-
// fft overwrites the input, so we need to copy it to a temporary buffer
529-
auto tempK = g.cBufXY();
530-
auto temp = g.rBufXY();
531-
532-
g.for_each_kxky([&](Dim kx, Dim ky) { tempK(kx, ky) = view(kx, ky); });
533-
534-
tf.bfft(tempK.to_mdspan(), temp.to_mdspan());
535-
tf.normalize(temp.to_mdspan(), temp.to_mdspan());
536-
537-
exportToNpy(std::move(path), temp.to_mdspan());
538-
}
539-
540522
Naive::Energies Naive::calculateEnergies() const {
541523
Energies e{};
542524
g.for_each_kxky([&](Dim kx, Dim ky) {

cpp/lib/Naive.hpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "Exporter.hpp"
34
#include "HermiteRunner.hpp"
45
#include "Transformer.hpp"
56
#include "constants.hpp"
@@ -34,9 +35,10 @@ class Naive : public ahr::HermiteRunner {
3435
mdarray<Real, dextents<Dim, 2u>> getFinalAPar() override;
3536

3637

37-
private:
3838
Grid g;
3939
Transformer tf{g};
40+
Exporter exporter{g, tf};
41+
private:
4042

4143
using View = Grid::View;
4244
using Buf = Grid::Buf;
@@ -224,14 +226,6 @@ class Naive : public ahr::HermiteRunner {
224226
private:
225227
Real updateTimestep(Real dt, Real tempDt, bool noInc, Real relative_error) const;
226228

227-
public:
228-
// TODO(luka) separate exporting utility
229-
void exportToNpy(std::string path, View::R_XY view) const;
230-
231-
// Will also normalize and inverseFFT
232-
void exportToNpy(std::string path, View::C_XY view) const;
233-
234-
private:
235229
void exportTimestep(Dim t);
236230
};
237231
}; // namespace ahr

cpp/lib/NpyMdspan.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include "grid.hpp"
4+
#include <cnpy.h>
5+
#include <filesystem>
6+
7+
namespace ahr {
8+
namespace fs = std::filesystem;
9+
10+
/// Owning holder of a npy array with convenience to see it as an mdspan.
11+
/// We can use this to avoid needlessly copying into an mdarray.
12+
class NpyMdspan {
13+
cnpy::NpyArray array_;
14+
15+
public:
16+
explicit NpyMdspan(cnpy::NpyArray array) : array_(std::move(array)) {}
17+
18+
// Layout-right
19+
using ViewXY = stdex::mdspan<Real, stdex::dextents<size_t, 2u>>;
20+
21+
// TODO(luka) const view
22+
Grid::View::R_XY view() {
23+
// Reverse the dimensions
24+
auto shape = array_.shape;
25+
std::reverse(shape.begin(), shape.end());
26+
std::span<size_t, 2> const extents{shape.data(), 2};
27+
28+
// Return a layout_left view
29+
return Grid::View::R_XY{array_.data<Real>(), extents};
30+
}
31+
32+
[[nodiscard]] bool valid() const { return array_.word_size == sizeof(Real); }
33+
};
34+
35+
} // namespace ahr

cpp/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ endfunction()
2020

2121
add_viriato_test(grid.cpp)
2222
add_viriato_test(transformer.cpp)
23+
add_viriato_test(exporter.cpp)
2324

2425
add_viriato_test(naive-energies.cpp)
2526
add_viriato_test(naive-moments.cpp)

cpp/test/exporter.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#include "Exporter.hpp"
2+
#include "Transformer.hpp"
3+
#include "debug.hpp"
4+
#include "util.hpp"
5+
#include <gtest/gtest.h>
6+
7+
namespace ahr {
8+
9+
class TestExporter : public ::testing::Test {
10+
protected:
11+
Grid grid{5, 16, 24};
12+
Transformer tf{grid};
13+
fs::path const tmp_dir{fs::temp_directory_path()};
14+
Exporter exporter{grid, tf, tmp_dir};
15+
16+
TestExporter() { tf.init(); }
17+
};
18+
19+
TEST_F(TestExporter, RoundTripReal) {
20+
21+
auto rBuf = grid.rBufXY();
22+
grid.for_each_xy([&](Dim x, Dim y) { rBuf(x, y) = Real(x + y * grid.X); });
23+
24+
exporter.exportTo("test.npy", rBuf);
25+
26+
auto const path = tmp_dir / "test.npy";
27+
ASSERT_TRUE(fs::exists(path));
28+
29+
// Import buffer and compare
30+
auto rBuf2 = exporter.importRealBuf("test.npy");
31+
// No math, tolerance is 0
32+
EXPECT_THAT(rBuf2.to_mdspan(), MdspanElementsAllClose(rBuf.to_mdspan(), 0.0));
33+
34+
// Import into npy view and compare
35+
auto rNpy = exporter.importReal("test.npy");
36+
EXPECT_THAT(rNpy.view(), MdspanElementsAllClose(rBuf.to_mdspan(), 0.0));
37+
}
38+
39+
TEST_F(TestExporter, RoundTripComplex) {
40+
auto cBuf = grid.cBufXY();
41+
grid.for_each_kxky([&](Dim kx, Dim ky) {
42+
cBuf(kx, ky) = Complex(Real(kx + ky * grid.KX), Real(kx - ky * grid.KX));
43+
});
44+
45+
// Use absolute path this time
46+
exporter.exportTo(tmp_dir / "test.npy", cBuf);
47+
48+
auto const path = tmp_dir / "test.npy";
49+
ASSERT_TRUE(fs::exists(path));
50+
51+
// Import buffer and compare
52+
auto rBuf = exporter.importReal("test.npy");
53+
auto rBuf2 = grid.rBufXY();
54+
tf.bfft(cBuf, rBuf2);
55+
56+
// No math, tolerance is 0
57+
EXPECT_THAT(rBuf.view(), MdspanElementsAllClose(rBuf2.to_mdspan(), 0.0))
58+
<< "rBuf:\n"
59+
<< rBuf.view() << "\nrBuf2:\n"
60+
<< rBuf2.to_mdspan();
61+
}
62+
63+
} // namespace ahr

cpp/test/naive-moments.cpp

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,13 @@ constexpr Real ATOL = LOW_PRECISION_ATOL;
2727
template <param_like Param> class NaiveMomentsBase : public NaiveTester<Param> {
2828
protected:
2929
using Base = NaiveTester<Param>;
30-
std::string getFilename(Dim m) {
30+
fs::path getFilename(Dim m) {
3131
auto const p = Base::GetParam();
3232

3333
// Tests are run from inside the `test` directory
3434
auto const filename = p.to_param_str() + "_m" + std::to_string(m) + ".npy";
3535
return fs::current_path() / "_test_data" / filename;
3636
}
37-
38-
// Owning holder of a npy array with convenience to see it as an mdspan.
39-
// We can use this to avoid needlessly copying into an mdarray.
40-
class NpyMdspan {
41-
cnpy::NpyArray array_;
42-
43-
public:
44-
explicit NpyMdspan(cnpy::NpyArray array) : array_(std::move(array)) {}
45-
46-
// TODO(luka) const view
47-
Grid::View::R_XY view() {
48-
std::span<size_t, 2> const extents{array_.shape.data(), 2};
49-
return Grid::View::R_XY{array_.data<Real>(), extents};
50-
}
51-
52-
[[nodiscard]] bool valid() const { return array_.word_size == sizeof(Real); }
53-
};
54-
55-
NpyMdspan readMoment(Dim m) {
56-
auto const filename = getFilename(m);
57-
return NpyMdspan{cnpy::npy_load(filename)};
58-
}
5937
};
6038

6139
using MomentParam = WithEquilibrium<WithDiffusion<NaiveParam>>;
@@ -79,9 +57,9 @@ TEST_P(NaiveMoments, CheckMoments) {
7957
for (Dim m = 0; m < p.M; m++) {
8058
// To update values, uncomment these 2 lines
8159
// std::cout << "WARNING!: Overwriting " << getFilename(m) << std::endl;
82-
// naive.exportToNpy(getFilename(m), naive.getMoment(m));
60+
// naive.exporter.exportTo(getFilename(m), naive.getMoment(m));
8361

84-
auto npy = readMoment(m);
62+
auto npy = naive.exporter.importReal(getFilename(m));
8563
ASSERT_TRUE(npy.valid());
8664
auto const max_val = std::ranges::max(std::span(npy.view().data_handle(), npy.view().size()));
8765
auto const n = static_cast<Real>(p.N);

0 commit comments

Comments
 (0)