Skip to content

Commit ff13c73

Browse files
committed
feat(kernels): use binary search to find compartment
1 parent a7c51c6 commit ff13c73

File tree

1 file changed

+90
-52
lines changed

1 file changed

+90
-52
lines changed

apps/libs/simulation/public/simulation/kernels/move_kernel.hpp

Lines changed: 90 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,39 @@ namespace Simulation::KernelInline
2727
constexpr bool disable_move = false;
2828
constexpr bool enable_move = true;
2929

30+
// KOKKOS_INLINE_FUNCTION std::size_t __find_next_compartment(
31+
// const ConstNeighborsView<ComputeSpace>& neighbors,
32+
// const CumulativeProbabilityView<ComputeSpace>& cumulative_probability,
33+
// const std::size_t i_compartment,
34+
// const double random_number)
35+
// {
36+
// const int max_neighbor = static_cast<int>(neighbors.extent(1));
37+
38+
// std::size_t next =
39+
// neighbors(i_compartment, 0); // Default to the first neighbor
40+
41+
// // Iterate through the neighbors to find the appropriate next compartment
42+
// for (int k_neighbor = 0; k_neighbor < max_neighbor - 1; ++k_neighbor)
43+
// {
44+
45+
// // Get the cumulative probability range for the current neighbor
46+
// const auto pi = cumulative_probability(i_compartment, k_neighbor);
47+
// const auto pn = cumulative_probability(i_compartment, k_neighbor + 1);
48+
49+
// // Use of a Condition mask to avoid branching.
50+
// next = (random_number <= pn && pi <= random_number)
51+
// ? neighbors(i_compartment, k_neighbor + 1)
52+
// : next;
53+
// }
54+
55+
// return next; // Return the index of the chosen next compartment
56+
// }
57+
58+
/** @brief probably overkill binary search to find next compartment
59+
60+
Compared with the first (linear) implementation, it might not change anything in practice.
61+
Binary search is O(log(n)) vs. the first linear implementation, which is O(n)
62+
*/
3063
KOKKOS_INLINE_FUNCTION std::size_t __find_next_compartment(
3164
const ConstNeighborsView<ComputeSpace>& neighbors,
3265
const CumulativeProbabilityView<ComputeSpace>& cumulative_probability,
@@ -35,24 +68,19 @@ namespace Simulation::KernelInline
3568
{
3669
const int max_neighbor = static_cast<int>(neighbors.extent(1));
3770

38-
std::size_t next =
39-
neighbors(i_compartment, 0); // Default to the first neighbor
71+
int left = 0;
72+
int right = max_neighbor - 1;
4073

41-
// Iterate through the neighbors to find the appropriate next compartment
42-
for (int k_neighbor = 0; k_neighbor < max_neighbor - 1; ++k_neighbor)
74+
while (left < right)
4375
{
44-
45-
// Get the cumulative probability range for the current neighbor
46-
const auto pi = cumulative_probability(i_compartment, k_neighbor);
47-
const auto pn = cumulative_probability(i_compartment, k_neighbor + 1);
48-
49-
// Use of a Condition mask to avoid branching.
50-
next = (random_number <= pn && pi <= random_number)
51-
? neighbors(i_compartment, k_neighbor + 1)
52-
: next;
76+
int mid = (left + right) >> 1; // NOLINT
77+
const auto pm = cumulative_probability(i_compartment, mid);
78+
const int mask = static_cast<int>(random_number > pm);
79+
left = mask * (mid + 1) + (1 - mask) * left;
80+
right = mask * right + (1 - mask) * mid;
5381
}
5482

55-
return next; // Return the index of the chosen next compartment
83+
return neighbors(i_compartment, left);
5684
}
5785

5886
struct TagMove
@@ -206,10 +234,10 @@ namespace Simulation::KernelInline
206234
KOKKOS_ASSERT(i_current_compartment < move.liquid_volume.extent(0));
207235

208236
const bool mask_next =
209-
probability_leaving<int>(rng1,
210-
move.liquid_volume(i_current_compartment),
211-
move.diag_transition(i_current_compartment),
212-
d_t);
237+
probability_leaving<void>(rng1,
238+
move.liquid_volume(i_current_compartment),
239+
move.diag_transition(i_current_compartment),
240+
d_t);
213241

214242
positions(idx) =
215243
(mask_next) ? __find_next_compartment(move.neighbors,
@@ -229,52 +257,62 @@ namespace Simulation::KernelInline
229257
}
230258
}
231259

232-
KOKKOS_INLINE_FUNCTION void handle_exit(std::size_t idx,
233-
std::size_t& dead_count) const
260+
KOKKOS_INLINE_FUNCTION void
261+
inner_handle_exit(const std::size_t i_flow,
262+
const std::size_t idx,
263+
std ::size_t& dead_count,
264+
const double liquid_volume,
265+
const std::size_t position) const
234266
{
235-
const auto position = positions(idx);
236-
const std::size_t n_flow = move.leaving_flow.size();
237-
const auto liquid_volume = move.liquid_volume(position);
238-
239-
for (std::size_t i = 0LU; i < n_flow; ++i)
240-
{
267+
// const float random_number = random(idx, i);
241268

242-
// const float random_number = random(idx, i);
243-
auto generator = random_pool.get_state();
244-
const float random_number = generator.frand(0., 1.);
245-
random_pool.free_state(generator);
269+
SAMPLE_RANDOM_VARIABLES(random_pool,
270+
const float random_number =
271+
_generator_state_.frand(0., 1.);)
246272

247-
const auto& [index, flow] = move.leaving_flow(i);
273+
const auto& [index, flow] = move.leaving_flow(i_flow);
248274

249-
const bool is_leaving =
250-
(position == index) &&
251-
probability_leaving(random_number, liquid_volume, flow, d_t);
275+
const bool is_leaving =
276+
(position == index) &&
277+
probability_leaving<void>(random_number, liquid_volume, flow, d_t);
252278

253-
const int leave_mask = static_cast<int>(is_leaving);
279+
const int leave_mask = static_cast<int>(is_leaving);
254280

255-
// If using probes
256-
if constexpr (AutoGenerated::FlagCompileTime::use_probe)
281+
// If using probes
282+
if constexpr (AutoGenerated::FlagCompileTime::use_probe)
283+
{
284+
// Execute probe set, but only actually do something if leaving
285+
const auto probe_value = ages(idx, 0);
286+
if (is_leaving && !probes.set(probe_value))
257287
{
258-
// Execute probe set, but only actually do something if leaving
259-
const auto probe_value = ages(idx, 0);
260-
if (is_leaving && !probes.set(probe_value))
261-
{
262-
Kokkos::printf("[Kernel]: PROBES OVERFLOW\r\n");
263-
}
288+
Kokkos::printf("[Kernel]: PROBES OVERFLOW\r\n");
264289
}
265-
if constexpr (AutoGenerated::FlagCompileTime::enable_event_counter)
290+
}
291+
if constexpr (AutoGenerated::FlagCompileTime::enable_event_counter)
292+
{
293+
if (is_leaving)
266294
{
267-
if (is_leaving)
268-
{
269-
events.incr<MC::EventType::Exit>();
270-
}
295+
events.incr<MC::EventType::Exit>();
271296
}
297+
}
272298

273-
dead_count += leave_mask;
274-
ages(idx, 0) = leave_mask * 0 + (1 - leave_mask) * ages(idx, 0);
275-
status(idx) = is_leaving ? MC::Status::Exit : status(idx);
299+
dead_count += leave_mask;
300+
ages(idx, 0) = leave_mask * 0 + (1 - leave_mask) * ages(idx, 0);
301+
// status(idx) = is_leaving ? MC::Status::Exit : status(idx);
302+
status(idx) = static_cast<MC::Status>(
303+
static_cast<int>(MC::Status::Exit) * leave_mask +
304+
(1 - leave_mask) * static_cast<int>(status(idx)));
305+
}
276306

277-
// events.wrap_incr<MC::EventType::Exit>();
307+
KOKKOS_INLINE_FUNCTION void handle_exit(std::size_t idx,
308+
std::size_t& dead_count) const
309+
{
310+
const auto position = positions(idx);
311+
const std::size_t n_flow = move.leaving_flow.size();
312+
const auto liquid_volume = move.liquid_volume(position);
313+
for (std::size_t i = 0LU; i < n_flow; ++i)
314+
{
315+
inner_handle_exit(i, idx, dead_count, liquid_volume, position);
278316
}
279317
}
280318

0 commit comments

Comments
 (0)