Skip to content

Commit a497f42

Browse files
authored
Merge pull request #620 from astrorama/feature/psfstack_optimization
Optimize VariablePsfStack using KdTree
2 parents 519ad8f + 8f1e8a4 commit a497f42

File tree

9 files changed

+534
-43
lines changed

9 files changed

+534
-43
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ __pycache__
1111
/packages
1212
/doc/build/
1313
/build/
14+
.vscode

SEBenchmarks/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ elements_add_executable(BenchRendering src/program/BenchRendering.cpp
5656
LINK_LIBRARIES SEFramework SEImplementation ${Boost_LIBRARIES})
5757
elements_add_executable(BenchBackgroundModel src/program/BenchBackgroundModel.cpp
5858
LINK_LIBRARIES SEFramework SEImplementation ${Boost_LIBRARIES})
59+
elements_add_executable(BenchVariablePsfStack src/program/BenchVariablePsfStack.cpp
60+
LINK_LIBRARIES SEFramework ${Boost_LIBRARIES})
5961

6062
#===============================================================================
6163
# Declare the Boost tests here
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/** Copyright © 2019-2025 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
2+
*
3+
* This library is free software; you can redistribute it and/or modify it under
4+
* the terms of the GNU Lesser General Public License as published by the Free
5+
* Software Foundation; either version 3.0 of the License, or (at your option)
6+
* any later version.
7+
*
8+
* This library is distributed in the hope that it will be useful, but WITHOUT
9+
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
10+
* FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
11+
* details.
12+
*
13+
* You should have received a copy of the GNU Lesser General Public License
14+
* along with this library; if not, write to the Free Software Foundation, Inc.,
15+
* 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16+
*/
17+
18+
/**
19+
* @file src/program/BenchVariablePsfStack.cpp
20+
* @date 06/27/25
21+
* @author marc schefer
22+
*/
23+
24+
#include <map>
25+
#include <string>
26+
#include <random>
27+
28+
#include <boost/program_options.hpp>
29+
#include <boost/timer/timer.hpp>
30+
#include "ElementsKernel/ProgramHeaders.h"
31+
#include "ElementsKernel/Real.h"
32+
#include "SEFramework/Psf/VariablePsfStack.h"
33+
34+
namespace po = boost::program_options;
35+
namespace timer = boost::timer;
36+
using namespace SourceXtractor;
37+
38+
static Elements::Logging logger = Elements::Logging::getLogger("BenchVariablePsfStack");
39+
40+
class BenchVariablePsfStack : public Elements::Program {
41+
private:
42+
std::default_random_engine random_generator;
43+
std::uniform_real_distribution<double> random_dist{0.0, 1000.0};
44+
45+
public:
46+
47+
po::options_description defineSpecificProgramOptions() override {
48+
po::options_description options{};
49+
options.add_options()
50+
("iterations", po::value<int>()->default_value(100000), "Number of getPsf calls to benchmark")
51+
("measures", po::value<int>()->default_value(3), "Number of timing measurements to take")
52+
("fits-file", po::value<std::string>()->default_value(""), "FITS file containing PSF stack");
53+
return options;
54+
}
55+
56+
Elements::ExitCode mainMethod(std::map<std::string, po::variable_value> &args) override {
57+
58+
auto iterations = args["iterations"].as<int>();
59+
auto measures = args["measures"].as<int>();
60+
auto fits_file = args["fits-file"].as<std::string>();
61+
62+
logger.info() << "Benchmarking VariablePsfStack::getPsf() with " << iterations << " iterations";
63+
logger.info() << "Taking " << measures << " timing measurements";
64+
65+
// Initialize VariablePsfStack with FITS file if provided, otherwise nullptr
66+
std::shared_ptr<CCfits::FITS> fitsPtr = nullptr;
67+
if (!fits_file.empty()) {
68+
try {
69+
fitsPtr = std::make_shared<CCfits::FITS>(fits_file);
70+
logger.info() << "Using FITS file: " << fits_file;
71+
} catch (const std::exception& e) {
72+
logger.error() << "Failed to load FITS file '" << fits_file << "': " << e.what();
73+
return Elements::ExitCode::DATAERR;
74+
}
75+
} else {
76+
logger.error() << "No FITS file provided";
77+
return Elements::ExitCode::USAGE;
78+
}
79+
80+
try {
81+
VariablePsfStack psfStack(fitsPtr);
82+
83+
logger.info() << "VariablePsfStack loaded successfully with " << psfStack.getNumberOfPsfs() << " PSFs";
84+
logger.info() << "PSF size: " << psfStack.getWidth() << "x" << psfStack.getHeight();
85+
logger.info() << "Pixel sampling: " << psfStack.getPixelSampling();
86+
87+
std::cout << "Iterations,Measurement,Time_nanoseconds" << std::endl;
88+
89+
for (int m = 0; m < measures; ++m) {
90+
logger.info() << "Measurement " << (m + 1) << "/" << measures;
91+
92+
timer::cpu_timer timer;
93+
timer.stop();
94+
95+
// Prepare test values for getPsf calls
96+
std::vector<std::vector<double>> testValues;
97+
testValues.reserve(iterations);
98+
99+
for (int i = 0; i < iterations; ++i) {
100+
testValues.push_back({random_dist(random_generator), random_dist(random_generator)});
101+
}
102+
103+
// Start timing
104+
timer.start();
105+
106+
for (int i = 0; i < iterations; ++i) {
107+
try {
108+
auto psf = psfStack.getPsf(testValues[i]);
109+
// Prevent compiler optimization by using the result
110+
volatile auto width = psf->getWidth();
111+
(void)width; // Suppress unused variable warning
112+
} catch (const std::exception& e) {
113+
// Expected to fail with nullptr, but we still measure the timing
114+
// until the exception is thrown
115+
}
116+
}
117+
118+
timer.stop();
119+
120+
auto elapsed_ns = timer.elapsed().wall;
121+
std::cout << iterations << "," << (m + 1) << "," << elapsed_ns << std::endl;
122+
123+
logger.info() << "Time for " << iterations << " calls: " << (elapsed_ns / 1e9) << " seconds";
124+
logger.info() << "Average time per call: " << (elapsed_ns / iterations) << " nanoseconds";
125+
}
126+
127+
} catch (const std::exception& e) {
128+
logger.error() << "Error initializing VariablePsfStack: " << e.what();
129+
return Elements::ExitCode::DATAERR;
130+
}
131+
132+
return Elements::ExitCode::OK;
133+
}
134+
};
135+
136+
MAIN_FOR(BenchVariablePsfStack)

SEFramework/SEFramework/Psf/VariablePsfStack.h

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@
2525
#define _SEIMPLEMENTATION_PSF_VARIABLEPSFSTACK_H_
2626

2727
#include <CCfits/CCfits>
28+
#include <memory>
2829
#include <SEFramework/Image/VectorImage.h>
2930
#include <SEFramework/Psf/Psf.h>
31+
#include <SEUtils/KdTree.h>
3032

3133
namespace SourceXtractor {
3234

@@ -42,6 +44,18 @@ namespace SourceXtractor {
4244
*/
4345
class VariablePsfStack final : public Psf {
4446
public:
47+
/**
48+
* @brief Structure to hold PSF position data
49+
*/
50+
struct PsfPosition {
51+
SeFloat ra;
52+
SeFloat dec;
53+
SeFloat x;
54+
SeFloat y;
55+
double gridx;
56+
double gridy;
57+
};
58+
4559
/**
4660
* Constructor
4761
*/
@@ -83,6 +97,13 @@ class VariablePsfStack final : public Psf {
8397
return m_components;
8498
};
8599

100+
/**
101+
* @return The number of PSFs loaded in the stack
102+
*/
103+
long getNumberOfPsfs() const {
104+
return m_nrows;
105+
};
106+
86107
/**
87108
*
88109
*/
@@ -99,12 +120,8 @@ class VariablePsfStack final : public Psf {
99120

100121
long m_nrows;
101122

102-
std::vector<SeFloat> m_ra_values;
103-
std::vector<SeFloat> m_dec_values;
104-
std::vector<SeFloat> m_x_values;
105-
std::vector<SeFloat> m_y_values;
106-
std::vector<double> m_gridx_values;
107-
std::vector<double> m_gridy_values;
123+
std::vector<PsfPosition> m_positions;
124+
std::unique_ptr<KdTree<PsfPosition>> m_kdtree;
108125

109126
std::vector<std::string> m_components = {"X_IMAGE", "Y_IMAGE"};
110127

@@ -119,6 +136,16 @@ class VariablePsfStack final : public Psf {
119136
void selfTest();
120137
};
121138

139+
/**
140+
* @brief KdTree traits specialization for PsfPosition
141+
*/
142+
template <>
143+
struct KdTreeTraits<VariablePsfStack::PsfPosition> {
144+
static double getCoord(const VariablePsfStack::PsfPosition& pos, size_t index) {
145+
return (index == 0) ? static_cast<double>(pos.x) : static_cast<double>(pos.y);
146+
}
147+
};
148+
122149
} // namespace SourceXtractor
123150

124151
#endif //_SEIMPLEMENTATION_PSF_VARIABLEPSFSTACK_H_

SEFramework/src/lib/Psf/VariablePsfStack.cpp

Lines changed: 51 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* Author: Martin Kuemmel
2222
*/
2323
#include <algorithm>
24+
#include <memory>
2425
#include <ElementsKernel/Logging.h>
2526
#include <ElementsKernel/Exception.h>
2627
#include "SEFramework/Psf/VariablePsfStack.h"
@@ -73,18 +74,33 @@ void VariablePsfStack::setup(std::shared_ptr<CCfits::FITS> pFits) {
7374
// read the nrows value
7475
m_nrows = position_data.rows();
7576

77+
// Temporary vectors for reading data
78+
std::vector<SeFloat> ra_values, dec_values, x_values, y_values;
79+
std::vector<double> gridx_values, gridy_values;
80+
7681
try {
7782
// read in all the EXT specific columns
78-
position_data.column("GRIDX", false).read(m_gridx_values, 0, m_nrows);
79-
position_data.column("GRIDY", false).read(m_gridy_values, 0, m_nrows);
83+
position_data.column("GRIDX", false).read(gridx_values, 0, m_nrows);
84+
position_data.column("GRIDY", false).read(gridy_values, 0, m_nrows);
8085
} catch (CCfits::Table::NoSuchColumn) {
81-
position_data.column("X_CENTER", false).read(m_gridx_values, 0, m_nrows);
82-
position_data.column("Y_CENTER", false).read(m_gridy_values, 0, m_nrows);
86+
position_data.column("X_CENTER", false).read(gridx_values, 0, m_nrows);
87+
position_data.column("Y_CENTER", false).read(gridy_values, 0, m_nrows);
88+
}
89+
position_data.column("RA", false).read(ra_values, 0, m_nrows);
90+
position_data.column("DEC", false).read(dec_values, 0, m_nrows);
91+
position_data.column("X", false).read(x_values, 0, m_nrows);
92+
position_data.column("Y", false).read(y_values, 0, m_nrows);
93+
94+
// Populate the positions vector
95+
m_positions.reserve(m_nrows);
96+
for (long i = 0; i < m_nrows; ++i) {
97+
m_positions.push_back({ra_values[i], dec_values[i], x_values[i], y_values[i], gridx_values[i], gridy_values[i]});
98+
}
99+
100+
// Build KdTree for fast nearest neighbor searches
101+
if (!m_positions.empty()) {
102+
m_kdtree = std::make_unique<KdTree<PsfPosition>>(m_positions);
83103
}
84-
position_data.column("RA", false).read(m_ra_values, 0, m_nrows);
85-
position_data.column("DEC", false).read(m_dec_values, 0, m_nrows);
86-
position_data.column("X", false).read(m_x_values, 0, m_nrows);
87-
position_data.column("Y", false).read(m_y_values, 0, m_nrows);
88104

89105
} catch (CCfits::FitsException& e) {
90106
throw Elements::Exception() << "Error loading stacked PSF file: " << e.message();
@@ -95,54 +111,54 @@ void VariablePsfStack::selfTest() {
95111
int naxis1, naxis2;
96112

97113
// read in the min/max grid values in x/y
98-
const auto x_grid_minmax = std::minmax_element(begin(m_gridx_values), end(m_gridx_values));
99-
const auto y_grid_minmax = std::minmax_element(begin(m_gridy_values), end(m_gridy_values));
114+
auto x_grid_minmax = std::minmax_element(m_positions.begin(), m_positions.end(),
115+
[](const PsfPosition& a, const PsfPosition& b) { return a.gridx < b.gridx; });
116+
auto y_grid_minmax = std::minmax_element(m_positions.begin(), m_positions.end(),
117+
[](const PsfPosition& a, const PsfPosition& b) { return a.gridy < b.gridy; });
100118

101119
// read the image size
102120
m_pFits->extension(1).readKey("NAXIS1", naxis1);
103121
m_pFits->extension(1).readKey("NAXIS2", naxis2);
104122

105123
// make sure all PSF in the grid are there
106-
if (*x_grid_minmax.first - m_grid_offset < 1)
107-
throw Elements::Exception() << "The PSF at the smallest x-grid starts at: " << *x_grid_minmax.first - m_grid_offset;
108-
if (*y_grid_minmax.first - m_grid_offset < 1)
109-
throw Elements::Exception() << "The PSF at the smallest y-grid starts at: " << *y_grid_minmax.first - m_grid_offset;
110-
if (*x_grid_minmax.second + m_grid_offset > naxis1)
111-
throw Elements::Exception() << "The PSF at the largest x-grid is too large: " << *x_grid_minmax.second + m_grid_offset
124+
if (x_grid_minmax.first->gridx - m_grid_offset < 1)
125+
throw Elements::Exception() << "The PSF at the smallest x-grid starts at: " << x_grid_minmax.first->gridx - m_grid_offset;
126+
if (y_grid_minmax.first->gridy - m_grid_offset < 1)
127+
throw Elements::Exception() << "The PSF at the smallest y-grid starts at: " << y_grid_minmax.first->gridy - m_grid_offset;
128+
if (x_grid_minmax.second->gridx + m_grid_offset > naxis1)
129+
throw Elements::Exception() << "The PSF at the largest x-grid is too large: " << x_grid_minmax.second->gridx + m_grid_offset
112130
<< " NAXIS1: " << naxis1;
113-
if (*y_grid_minmax.second + m_grid_offset > naxis2)
114-
throw Elements::Exception() << "The PSF at the largest y-grid is too large: " << *y_grid_minmax.second + m_grid_offset
115-
<< " NAXIS2: " << naxis1;
131+
if (y_grid_minmax.second->gridy + m_grid_offset > naxis2)
132+
throw Elements::Exception() << "The PSF at the largest y-grid is too large: " << y_grid_minmax.second->gridy + m_grid_offset
133+
<< " NAXIS2: " << naxis2;
116134
}
117135

118136
std::shared_ptr<VectorImage<SeFloat>> VariablePsfStack::getPsf(const std::vector<double>& values) const {
119-
long index_min_distance = 0;
120-
double min_distance = 1.0e+32;
121-
122137
// make sure there are only two positions
123138
if (values.size() != 2)
124139
throw Elements::Exception() << "There can be only two positional value for the stacked PSF!";
125140

126-
// find the position of minimal distance
127-
for (int act_index = 0; act_index < m_nrows; act_index++) {
128-
double act_distance = (values[0] - m_x_values[act_index]) * (values[0] - m_x_values[act_index]) +
129-
(values[1] - m_y_values[act_index]) * (values[1] - m_y_values[act_index]);
130-
if (act_distance < min_distance) {
131-
index_min_distance = act_index;
132-
min_distance = act_distance;
133-
}
134-
}
141+
// Use KdTree to find the nearest PSF position
142+
KdTree<PsfPosition>::Coord coord;
143+
coord.coord[0] = values[0]; // x coordinate
144+
coord.coord[1] = values[1]; // y coordinate
145+
146+
PsfPosition nearest_position = m_kdtree->findNearest(coord);
147+
148+
// Calculate distance for logging
149+
double min_distance = (values[0] - nearest_position.x) * (values[0] - nearest_position.x) +
150+
(values[1] - nearest_position.y) * (values[1] - nearest_position.y);
151+
135152
// give some feedback
136153
stack_logger.debug() << "Distance: " << sqrt(min_distance) << " (" << values[0] << "," << values[1] << ")<-->("
137-
<< m_x_values[index_min_distance] << "," << m_y_values[index_min_distance]
138-
<< ") index: " << index_min_distance;
154+
<< nearest_position.x << "," << nearest_position.y << ")";
139155

140156
// get the first and last pixels for the PSF to be extracted
141157
// NOTE: CCfits has 1-based indices, also the last index is *included* in the reading
142158
// NOTE: the +0.5 forces a correct cast/ceiling
143-
std::vector<long> first_vertex{long(m_gridx_values[index_min_distance]+.5) - long(m_grid_offset), long(m_gridy_values[index_min_distance]+.5) - long(m_grid_offset)};
159+
std::vector<long> first_vertex{long(nearest_position.gridx+.5) - long(m_grid_offset), long(nearest_position.gridy+.5) - long(m_grid_offset)};
144160
stack_logger.debug() << "First vertex: ( " << first_vertex[0] << ", " << first_vertex[1] << ") First vertex alternative: " <<
145-
m_gridx_values[index_min_distance]-m_grid_offset << " " << m_gridy_values[index_min_distance]-m_grid_offset <<
161+
nearest_position.gridx-m_grid_offset << " " << nearest_position.gridy-m_grid_offset <<
146162
" grid offset:" << m_grid_offset;
147163

148164
std::vector<long> last_vertex{first_vertex[0] + long(m_psf_size) - 1, first_vertex[1] +long( m_psf_size) - 1};
@@ -155,8 +171,6 @@ std::shared_ptr<VectorImage<SeFloat>> VariablePsfStack::getPsf(const std::vector
155171
m_pFits->extension(1).read(stamp_data, first_vertex, last_vertex, stride);
156172
}
157173

158-
//stack_logger.info() << "DDD ( " << first_vertex[0] << ", " << first_vertex[1] << ") --> ( " << last_vertex[0] << ", " << last_vertex[1] << "): " << stamp_data.size();
159-
160174
// create and return the psf image
161175
return VectorImage<SeFloat>::create(m_psf_size, m_psf_size, std::begin(stamp_data), std::end(stamp_data));
162176
}

SEUtils/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ elements_add_unit_test(Misc_test tests/src/Misc_test.cpp
6666
elements_add_unit_test(QuadTree_test tests/src/QuadTree_test.cpp
6767
LINK_LIBRARIES SEUtils
6868
TYPE Boost)
69+
elements_add_unit_test(KdTree_test tests/src/KdTree_test.cpp
70+
LINK_LIBRARIES SEUtils
71+
TYPE Boost)
6972

7073
if(GMOCK_FOUND)
7174
elements_add_unit_test(Observable_test tests/src/Observable_test.cpp

SEUtils/SEUtils/KdTree.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <vector>
2222
#include <memory>
2323
#include <algorithm>
24+
#include <limits>
2425

2526
namespace SourceXtractor {
2627

@@ -49,6 +50,7 @@ class KdTree {
4950

5051
explicit KdTree(const std::vector<T>& data);
5152
std::vector<T> findPointsWithinRadius(Coord coord, double radius) const;
53+
T findNearest(Coord coord) const;
5254

5355
private:
5456
class Node;

0 commit comments

Comments
 (0)