Skip to content

Commit 626966b

Browse files
authored
feat: adding operator== for grids (and material grids) (#884)
This PR adds the necessary operator== to make grids comparable. It also adds unit tests for the comparisons of grids, particularly for material grids.
1 parent ee69f28 commit 626966b

File tree

10 files changed

+420
-0
lines changed

10 files changed

+420
-0
lines changed

core/include/detray/utils/grid/detail/axis.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ struct single_axis {
8383
const vector_type<scalar_type> *edges)
8484
: m_binning(indx_range, edges) {}
8585

86+
/// Equality operator
87+
///
88+
/// @param rhs is the right-hand side of the comparison
89+
///
90+
/// @returns whether the two axes are equal
91+
constexpr bool operator==(const single_axis &rhs) const = default;
92+
8693
/// @returns the axis label, i.e. x, y, z, r or phi axis.
8794
DETRAY_HOST_DEVICE
8895
constexpr auto label() const -> axis::label { return bounds_type::label; }
@@ -401,6 +408,25 @@ class multi_axis {
401408
detray::get_data(m_edges)};
402409
}
403410

411+
/// Equality operator
412+
///
413+
/// @param rhs the right-hand side of the comparison
414+
///
415+
/// @note in the non-owning case, we compare the values not the pointers
416+
///
417+
/// @returns whether the two axes are equal
418+
DETRAY_HOST_DEVICE constexpr auto operator==(const multi_axis &rhs) const
419+
-> bool {
420+
if constexpr (!std::is_pointer_v<edge_range_t>) {
421+
return m_edge_offsets == rhs.m_edge_offsets &&
422+
m_edges == rhs.m_edges;
423+
} else {
424+
return m_edge_offsets == rhs.m_edge_offsets &&
425+
*m_edges == *rhs.m_edges;
426+
}
427+
return false;
428+
}
429+
404430
private:
405431
/// Get the number of bins for a single axis.
406432
///

core/include/detray/utils/grid/detail/axis_binning.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,19 @@ struct regular {
167167

168168
return {min, max};
169169
}
170+
171+
/// Equality operator
172+
///
173+
/// @param rhs the axis to compare to
174+
///
175+
/// @note as we cannot guarantee to have the same pointer for the bin edges,
176+
/// we make a fast comparison of the pointer first, but also allow for a
177+
/// value based comparison
178+
///
179+
/// @returns whether the two axes are equal
180+
DETRAY_HOST_DEVICE constexpr bool operator==(const regular &rhs) const {
181+
return (nbins() == rhs.nbins()) && (span() == rhs.span());
182+
}
170183
};
171184

172185
/// @brief An irregular binning scheme.
@@ -306,6 +319,36 @@ struct irregular {
306319

307320
return {min, max};
308321
}
322+
323+
/// Equality operator
324+
///
325+
/// @param rhs the axis to compare to
326+
///
327+
/// @note as we cannot guarantee to have the same pointer for the bin edges,
328+
/// we make a fast comparison of the pointer first, but also allow for a
329+
/// value based comparison
330+
///
331+
/// @returns whether the two axes are equal
332+
DETRAY_HOST_DEVICE constexpr bool operator==(const irregular &rhs) const {
333+
if (m_n_bins != rhs.m_n_bins) {
334+
return false;
335+
}
336+
if (m_offset == rhs.m_offset && m_bin_edges == rhs.m_bin_edges) {
337+
return true;
338+
}
339+
auto edge_range_lhs = detray::ranges::subrange(
340+
*m_bin_edges, dindex_range{m_offset, m_offset + m_n_bins});
341+
auto edge_range_rhs = detray::ranges::subrange(
342+
*rhs.m_bin_edges,
343+
dindex_range{rhs.m_offset, rhs.m_offset + rhs.m_n_bins});
344+
345+
for (dindex i = 0; i < m_n_bins; ++i) {
346+
if (edge_range_lhs[i] != edge_range_rhs[i]) {
347+
return false;
348+
}
349+
}
350+
return true;
351+
}
309352
};
310353

311354
} // namespace detray::axis

core/include/detray/utils/grid/detail/axis_bounds.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ struct open {
8787
const std::size_t nbins) const noexcept {
8888
return map(range[0], range[1], nbins);
8989
}
90+
91+
/// Equality operator
92+
///
93+
/// @param rhs the open axis to compare with
94+
///
95+
/// @returns whether the two axes are equal
96+
constexpr bool operator==(const open &rhs) const = default;
9097
};
9198

9299
/// @brief Describes the behaviour of a closed axis.
@@ -155,6 +162,13 @@ struct closed {
155162
const std::size_t nbins) const {
156163
return map(range[0], range[1], nbins);
157164
}
165+
166+
/// Equality operator
167+
///
168+
/// @param rhs the open axis to compare with
169+
///
170+
/// @returns whether the two axes are equal
171+
constexpr bool operator==(const closed &rhs) const = default;
158172
};
159173

160174
/// @brief Describes the behaviour of a circular axis.
@@ -241,6 +255,13 @@ struct circular {
241255
const std::size_t nbins) const noexcept {
242256
return wrap(range[0], range[1], nbins);
243257
}
258+
259+
/// Equality operator
260+
///
261+
/// @param rhs the open axis to compare with
262+
///
263+
/// @returns whether the two axes are equal
264+
constexpr bool operator==(const circular &rhs) const = default;
244265
};
245266

246267
} // namespace detray::axis

core/include/detray/utils/grid/detail/bin_storage.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,16 @@ class bin_storage : public detray::ranges::view_interface<
112112
return detray::get_data(m_bin_data);
113113
}
114114

115+
/// Equality operator
116+
///
117+
/// @param rhs bin storage to compare with
118+
///
119+
/// @returns true if the bin data is equal
120+
DETRAY_HOST_DEVICE
121+
constexpr bool operator==(const bin_storage& rhs) const {
122+
return m_bin_data == rhs.m_bin_data;
123+
}
124+
115125
private:
116126
/// Container that holds all bin data when owning or a view into an
117127
/// externally owned container
@@ -453,6 +463,16 @@ class bin_storage<is_owning, detray::bins::dynamic_array<entry_t>, containers>
453463
detray::get_data(m_entry_data)};
454464
}
455465

466+
/// Equality operator
467+
///
468+
/// @param rhs bin storage to compare with
469+
///
470+
/// @returns true if the bin data is equal
471+
DETRAY_HOST_DEVICE
472+
constexpr bool operator==(const bin_storage& rhs) const {
473+
return m_bin_data == rhs.m_bin_data && m_entry_data == rhs.m_entry_data;
474+
}
475+
456476
private:
457477
/// Container that holds all bin data when owning or a view into an
458478
/// externally owned container

core/include/detray/utils/grid/detail/grid_bins.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,16 @@ class single : public detray::ranges::single_view<entry_t> {
5151
(*this).ref() = entry;
5252
return *this;
5353
}
54+
55+
/// Equality operator
56+
///
57+
/// @param rhs the single view to compare with
58+
///
59+
/// @returns true if the single value is equal
60+
DETRAY_HOST_DEVICE
61+
constexpr bool operator==(const single& rhs) const {
62+
return (*this).value() == rhs.value();
63+
}
5464
};
5565

5666
/// @brief Bin that holds a collection of entries.
@@ -153,6 +163,16 @@ class static_array
153163
return *this;
154164
}
155165

166+
/// Equality operator
167+
///
168+
/// @param rhs the bin entry to compare with
169+
///
170+
/// @returns true if the content is equal
171+
DETRAY_HOST_DEVICE
172+
constexpr bool operator==(const static_array& rhs) const {
173+
return m_content == rhs.m_content;
174+
}
175+
156176
private:
157177
/// @returns the subrange on the valid bin content - const
158178
DETRAY_HOST_DEVICE constexpr auto view() const {
@@ -187,6 +207,8 @@ class dynamic_array
187207
dindex size{0u};
188208
dindex capacity{0u};
189209

210+
constexpr bool operator==(const data& rhs) const = default;
211+
190212
DETRAY_HOST_DEVICE
191213
constexpr void update_offset(std::size_t shift) {
192214
offset += static_cast<dindex>(shift);
@@ -304,6 +326,33 @@ class dynamic_array
304326
return *this;
305327
}
306328

329+
/// Equality operator
330+
///
331+
/// @param rhs the bin to be compared with
332+
///
333+
/// @returns true if the view is identical
334+
DETRAY_HOST_DEVICE
335+
constexpr bool operator==(const dynamic_array& rhs) const {
336+
// Check if the bin points to the same data
337+
if (m_data == rhs.m_data || *m_data == *rhs.m_data) {
338+
return true;
339+
}
340+
if (m_data->size != rhs.m_data->size) {
341+
return false;
342+
}
343+
// It could still point to different data, but the
344+
// content is the same
345+
auto this_view = view();
346+
auto rhs_view = rhs.view();
347+
// Loop over the size of the bin and compare
348+
for (dindex i{0u}; i < m_data->size; ++i) {
349+
if (this_view[i] != rhs_view[i]) {
350+
return false;
351+
}
352+
}
353+
return true;
354+
}
355+
307356
private:
308357
/// @returns the subrange on the valid bin content - const
309358
DETRAY_HOST_DEVICE auto view() const {

core/include/detray/utils/grid/grid.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,29 @@ class grid_impl {
361361
detray::get_data(m_axes)};
362362
}
363363

364+
/// Equality comparison
365+
///
366+
/// @param rhs the right-hand side of the comparison
367+
///
368+
/// @note grids could have different bin storage ranges, but could still be
369+
/// identical, hence compare the actual grid content
370+
///
371+
/// @returns whether the two grids are equal
372+
DETRAY_HOST_DEVICE constexpr auto operator==(const grid_impl &rhs) const
373+
-> bool {
374+
// Check axes: they need to be identical
375+
if (m_axes != rhs.m_axes) {
376+
return false;
377+
}
378+
// Loop over global bin index and compare the two
379+
for (glob_bin_index i = 0; i < nbins(); ++i) {
380+
if (bin(i) != rhs.bin(i)) {
381+
return false;
382+
}
383+
}
384+
return true;
385+
}
386+
364387
private:
365388
/// Struct that contains the grid's data state
366389
bin_storage m_bins{};

core/include/detray/utils/ranges/subrange.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ class subrange : public detray::ranges::view_interface<subrange<range_t>> {
8787
DETRAY_HOST_DEVICE
8888
constexpr auto end() const -> const_iterator_t { return m_end; }
8989

90+
/// Equality operator
91+
///
92+
/// @param rhs the subrange to compare with
93+
///
94+
/// @returns whether the two subranges are equal
95+
DETRAY_HOST_DEVICE
96+
constexpr auto operator==(const subrange &rhs) const -> bool {
97+
return m_begin == rhs.m_begin && m_end == rhs.m_end;
98+
}
99+
90100
private:
91101
/// Start and end position of the subrange
92102
iterator_t m_begin;

tests/unit_tests/cpu/material/material_maps.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,49 @@ GTEST_TEST(detray_material, trapezoid_map) {
265265
EXPECT_FALSE(trapezoid_map.at(199, 0) ==
266266
material_t(aluminium<scalar>{}, 201.f * unit<scalar>::mm));
267267
}
268+
269+
/// Unittest: Test the material grid comparisons
270+
GTEST_TEST(detray_material, material_grid_comparison) {
271+
272+
/** Allows to create a regular grid to check the equality opeartor
273+
* grids can differ in:
274+
* - type (will never be compared)
275+
* - bins/axes
276+
* - entries (i.e. material data)
277+
*/
278+
auto createGrid = [](const scalar hx, const scalar hy, unsigned int bx,
279+
unsigned int by, bool distort_entries) {
280+
mask<rectangle2D> r2{0u, hx, hy};
281+
auto material_grid = mat_map_factory.new_grid(r2, {bx, by});
282+
283+
// Fill the material grid with some data
284+
scalar thickness = 2.f * unit<scalar>::mm;
285+
for (dindex gbin = 0; gbin < material_grid.nbins(); ++gbin) {
286+
material_grid.template populate<replace<>>(
287+
gbin, material_t(oxygen_gas<scalar>{}, thickness));
288+
thickness += 1.f * unit<scalar>::mm;
289+
if (distort_entries) {
290+
thickness += 1.f * unit<scalar>::mm;
291+
}
292+
}
293+
// Return it
294+
return material_grid;
295+
};
296+
297+
// Two equal grids
298+
auto grid_ref = createGrid(10.f, 20.f, 10u, 20u, false);
299+
auto grid_eq = createGrid(10.f, 20.f, 10u, 20u, false);
300+
EXPECT_EQ(grid_ref, grid_eq);
301+
302+
// One grid with different size
303+
auto grid_neq_size = createGrid(11.f, 21.f, 10u, 20u, false);
304+
EXPECT_NE(grid_ref, grid_neq_size);
305+
306+
// One grid with different binning
307+
auto grid_neq_bins = createGrid(10.f, 20.f, 11u, 21u, false);
308+
EXPECT_NE(grid_ref, grid_neq_bins);
309+
310+
// One grid with different entries
311+
auto grid_neq_entries = createGrid(10.f, 20.f, 10u, 20u, true);
312+
EXPECT_NE(grid_ref, grid_neq_entries);
313+
}

0 commit comments

Comments
 (0)