Skip to content

Commit 186eced

Browse files
authored
add consistency checks for allow halving/doubling flags. closes #266 (#362)
1 parent ebaffda commit 186eced

File tree

6 files changed

+130
-6
lines changed

6 files changed

+130
-6
lines changed

src/aero_state.hpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ auto pointer_vec_magic(arr_t &data_vec, const arg_t &arg) {
224224
struct AeroState {
225225
PMCResource ptr;
226226
std::shared_ptr<AeroData> aero_data;
227+
int allow_halving = -1, allow_doubling = -1;
227228

228229
AeroState(
229230
std::shared_ptr<AeroData> aero_data,
@@ -572,7 +573,7 @@ struct AeroState {
572573
}
573574

574575
static int dist_sample(
575-
const AeroState &self,
576+
AeroState &self,
576577
const AeroDist &aero_dist,
577578
const double &sample_prop,
578579
const double &create_time,
@@ -581,6 +582,15 @@ struct AeroState {
581582
) {
582583
int n_part_add = 0;
583584

585+
if (
586+
(self.allow_doubling != -1 && self.allow_doubling != allow_doubling) ||
587+
(self.allow_halving != -1 && self.allow_halving != allow_halving)
588+
)
589+
throw std::runtime_error("dist_sample() called with different halving/doubling settings then in last call");
590+
591+
self.allow_doubling = allow_doubling;
592+
self.allow_halving = allow_halving;
593+
584594
f_aero_state_add_aero_dist_sample(
585595
self.ptr.f_arg(),
586596
self.aero_data->ptr.f_arg(),

src/run_part.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
#include "run_part.hpp"
88
#include "pybind11/stl.h"
99

10+
void check_allow_flags(
11+
const AeroState &aero_state,
12+
const RunPartOpt &run_part_opt
13+
) {
14+
if (
15+
(aero_state.allow_halving != -1 && run_part_opt.allow_halving != aero_state.allow_halving) ||
16+
(aero_state.allow_doubling != -1 && run_part_opt.allow_doubling != aero_state.allow_doubling)
17+
)
18+
throw std::runtime_error("allow halving/doubling flags set differently then while sampling");
19+
}
20+
1021
void run_part(
1122
const Scenario &scenario,
1223
EnvState &env_state,
@@ -18,6 +29,7 @@ void run_part(
1829
const CampCore &camp_core,
1930
const Photolysis &photolysis
2031
) {
32+
check_allow_flags(aero_state, run_part_opt);
2133
f_run_part(
2234
scenario.ptr.f_arg(),
2335
env_state.ptr.f_arg_non_const(),
@@ -47,6 +59,7 @@ std::tuple<double, double, int> run_part_timestep(
4759
double &last_progress_time,
4860
int &i_output
4961
) {
62+
check_allow_flags(aero_state, run_part_opt);
5063
f_run_part_timestep(
5164
scenario.ptr.f_arg(),
5265
env_state.ptr.f_arg_non_const(),
@@ -84,6 +97,7 @@ std::tuple<double, double, int> run_part_timeblock(
8497
double &last_progress_time,
8598
int &i_output
8699
) {
100+
check_allow_flags(aero_state, run_part_opt);
87101
f_run_part_timeblock(
88102
scenario.ptr.f_arg(),
89103
env_state.ptr.f_arg_non_const(),

src/run_part_opt.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ extern "C" void f_run_part_opt_del_t(const void *ptr, double *del_t) noexcept;
1818

1919
struct RunPartOpt {
2020
PMCResource ptr;
21+
bool allow_halving, allow_doubling;
2122

2223
RunPartOpt(const nlohmann::json &json) :
2324
ptr(f_run_part_opt_ctor, f_run_part_opt_dtor)
@@ -39,6 +40,8 @@ struct RunPartOpt {
3940
}))
4041
if (json_copy.find(key) == json_copy.end())
4142
json_copy[key] = true;
43+
allow_halving = json_copy["allow_halving"];
44+
allow_doubling = json_copy["allow_doubling"];
4245

4346
for (auto key : std::set<std::string>({
4447
"t_output", "t_progress", "rand_init"

tests/test_aero_state.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,37 @@ def test_dist_sample_mono():
556556

557557
# assert
558558
assert np.isclose(np.array(sut.diameters()), diam).all()
559+
560+
@staticmethod
561+
@pytest.mark.parametrize(
562+
"args",
563+
(
564+
((True, True), (True, False)),
565+
((True, True), (False, True)),
566+
((True, True), (False, False)),
567+
((False, False), (True, False)),
568+
((False, False), (False, True)),
569+
((False, False), (True, True)),
570+
((True, False), (False, False)),
571+
((True, False), (False, True)),
572+
((False, True), (False, False)),
573+
((False, True), (True, False)),
574+
),
575+
)
576+
@pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348")
577+
def test_dist_sample_different_halving(args):
578+
# arrange
579+
aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL)
580+
aero_dist = ppmc.AeroDist(aero_data, [AERO_MODE_CTOR_SAMPLED])
581+
sut = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL)
582+
583+
# act
584+
with pytest.raises(RuntimeError) as excinfo:
585+
_ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[0])
586+
_ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[1])
587+
588+
# assert
589+
assert (
590+
str(excinfo.value)
591+
== "dist_sample() called with different halving/doubling settings then in last call"
592+
)

tests/test_output.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def test_input_netcdf(tmp_path):
6868
aero_dist,
6969
sample_prop=1.0,
7070
create_time=0.0,
71-
allow_doubling=True,
72-
allow_halving=True,
71+
allow_doubling=False,
72+
allow_halving=False,
7373
)
7474

7575
num_concs = aero_state.num_concs

tests/test_run_part.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# Authors: https://github.com/open-atmos/PyPartMC/graphs/contributors #
55
####################################################################################################
66

7+
import platform
8+
79
import numpy as np
810
import pytest
911

@@ -52,7 +54,7 @@ def test_run_part(common_args):
5254

5355
@staticmethod
5456
def test_run_part_timestep(common_args):
55-
(last_output_time, last_progress_time, i_output) = ppmc.run_part_timestep(
57+
last_output_time, last_progress_time, i_output = ppmc.run_part_timestep(
5658
*common_args, 1, 0, 0, 0, 1
5759
)
5860

@@ -63,14 +65,18 @@ def test_run_part_timestep(common_args):
6365

6466
@staticmethod
6567
def test_run_part_timeblock(common_args):
68+
# arrange
6669
num_times = int(
6770
RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"]
6871
/ RUN_PART_OPT_CTOR_ARG_SIMULATION["del_t"]
6972
)
70-
(last_output_time, last_progress_time, i_output) = ppmc.run_part_timeblock(
73+
74+
# act
75+
last_output_time, last_progress_time, i_output = ppmc.run_part_timeblock(
7176
*common_args, 1, num_times, 0, 0, 0, 1
7277
)
7378

79+
# assert
7480
assert last_output_time == RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"]
7581
assert last_progress_time == 0.0
7682
assert i_output == 2
@@ -94,8 +100,65 @@ def test_run_part_do_condensation(common_args, tmp_path):
94100
"do_condensation": True,
95101
}
96102
)
97-
aero_state.dist_sample(aero_dist, 1.0, 0.0, True, True)
103+
aero_state.dist_sample(aero_dist, 1.0, 0.0, False, False)
98104
ppmc.condense_equilib_particles(env_state, aero_data, aero_state)
99105
ppmc.run_part(*args)
100106

101107
assert np.sum(aero_state.masses(include=["H2O"])) > 0.0
108+
109+
@staticmethod
110+
@pytest.mark.parametrize(
111+
"flags",
112+
(
113+
((True, True), (True, False)),
114+
((True, True), (False, True)),
115+
((True, True), (False, False)),
116+
((False, False), (True, False)),
117+
((False, False), (False, True)),
118+
((False, False), (True, True)),
119+
((True, False), (False, False)),
120+
((True, False), (False, True)),
121+
((False, True), (False, False)),
122+
((False, True), (True, False)),
123+
),
124+
)
125+
@pytest.mark.parametrize(
126+
"fun_args",
127+
(
128+
("run_part", []),
129+
("run_part_timestep", [0, 0, 0, 0, 0]),
130+
("run_part_timeblock", [0, 0, 0, 0, 0, 0]),
131+
),
132+
)
133+
@pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348")
134+
def test_run_part_allow_flag_mismatch(common_args, tmp_path, fun_args, flags):
135+
# arrange
136+
filename = tmp_path / "test"
137+
env_state = ppmc.EnvState(ENV_STATE_CTOR_ARG_HIGH_RH)
138+
aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_FULL)
139+
aero_dist = ppmc.AeroDist(aero_data, AERO_DIST_CTOR_ARG_FULL)
140+
aero_state = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL)
141+
args = list(common_args)
142+
args[0].init_env_state(env_state, 0.0)
143+
args[1] = env_state
144+
args[2] = aero_data
145+
args[3] = aero_state
146+
args[6] = ppmc.RunPartOpt(
147+
{
148+
**RUN_PART_OPT_CTOR_ARG_SIMULATION,
149+
"output_prefix": str(filename),
150+
"allow_doubling": flags[0][0],
151+
"allow_halving": flags[0][1],
152+
}
153+
)
154+
aero_state.dist_sample(aero_dist, 1.0, 0.0, flags[1][0], flags[1][1])
155+
156+
# act
157+
with pytest.raises(RuntimeError) as excinfo:
158+
getattr(ppmc, fun_args[0])(*args, *fun_args[1])
159+
160+
# assert
161+
assert (
162+
str(excinfo.value)
163+
== "allow halving/doubling flags set differently then while sampling"
164+
)

0 commit comments

Comments
 (0)