Skip to content

Commit 6bc0825

Browse files
committed
Improvements according to review --- IGNORE ---
1 parent 42cbb92 commit 6bc0825

File tree

6 files changed

+45
-58
lines changed

6 files changed

+45
-58
lines changed

cpp/models/abm/analyze_result.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,6 @@ std::vector<Model> ensemble_params_percentile(const std::vector<std::vector<Mode
169169
return model.parameters.template get<DeathsPerInfectedCritical>()[{virus_variant, age_group}];
170170
});
171171

172-
param_percentile(node, [age_group, virus_variant](auto&& model) -> auto& {
173-
return model.parameters.template get<DetectInfection>()[{virus_variant, age_group}];
174-
});
175-
176172
param_percentile(node, [virus_variant](auto&& model) -> auto& {
177173
return model.parameters.template get<AerosolTransmissionRates>()[{virus_variant}];
178174
});
@@ -187,9 +183,9 @@ std::vector<Model> ensemble_params_percentile(const std::vector<std::vector<Mode
187183
return dist1.viral_load_peak < dist2.viral_load_peak;
188184
});
189185
param_percentile_dist(
190-
node, std::vector<ViralShedParameters>(num_runs),
186+
node, std::vector<ViralShedTuple>(num_runs),
191187
[age_group, virus_variant](auto&& model) -> auto& {
192-
return model.parameters.template get<ViralShedDistribution>()[{virus_variant, age_group}];
188+
return model.parameters.template get<ViralShedParameters>()[{virus_variant, age_group}];
193189
},
194190
[](auto& dist1, auto& dist2) {
195191
return dist1.virus_shed_alpha < dist2.virus_shed_alpha;

cpp/models/abm/infection.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace abm
3131
void Infection::initialize_viral_load(PersonalRandomNumberGenerator& rng, VirusVariant virus, AgeGroup age,
3232
const Parameters& params, TimePoint init_date, ProtectionEvent latest_protection)
3333
{
34-
auto vl_params = params.get<ViralLoadDistributions>()[{virus, age}];
34+
auto& vl_params = params.get<ViralLoadDistributions>()[{virus, age}];
3535
ScalarType high_viral_load_factor = 1;
3636

3737
if (latest_protection.type != ProtectionType::NoProtection) {
@@ -49,7 +49,7 @@ void Infection::initialize_viral_load(PersonalRandomNumberGenerator& rng, VirusV
4949
void Infection::initialize_viral_shed(PersonalRandomNumberGenerator& rng, VirusVariant virus, AgeGroup age,
5050
const Parameters& params)
5151
{
52-
auto viral_shed_params = params.get<ViralShedDistribution>()[{virus, age}];
52+
auto viral_shed_params = params.get<ViralShedParameters>()[{virus, age}];
5353
m_log_norm_alpha = viral_shed_params.viral_shed_alpha;
5454
m_log_norm_beta = viral_shed_params.viral_shed_beta;
5555

@@ -349,8 +349,8 @@ TimePoint Infection::draw_infection_course_forward(PersonalRandomNumberGenerator
349349
{
350350
TimePoint start_of_init_state = init_date; // since there is no state transition from Recovered or Dead
351351
bool init = true; // the random start time of the first state cannot be drawn
352-
auto& uniform_dist =
353-
UniformDistribution<double>::get_instance(); // thus, just use the init_date as the start for this
352+
// thus, just use the init_date as the start for this
353+
auto& init_state_dist = params.get<InitialInfectionStateDistributions>()[{m_virus_variant, age, start_state}];
354354

355355
TimePoint current_time = init_date;
356356
InfectionState current_state = start_state;
@@ -360,7 +360,7 @@ TimePoint Infection::draw_infection_course_forward(PersonalRandomNumberGenerator
360360
StateTransition transition =
361361
get_forward_transition(rng, age, params, current_state, current_time, latest_protection);
362362
if (init && current_state != InfectionState::Susceptible) { // random init within first time period
363-
ScalarType p = uniform_dist(rng);
363+
ScalarType p = init_state_dist.get(rng);
364364
start_of_init_state -= transition.duration.multiply(1 - p);
365365
transition.duration = transition.duration.multiply(p);
366366
init = false;

cpp/models/abm/parameters.h

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "abm/mask_type.h"
2424
#include "abm/time.h"
25+
#include "abm/infection_state.h"
2526
#include "abm/virus_variant.h"
2627
#include "abm/protection_event.h"
2728
#include "abm/protection_event.h"
@@ -276,6 +277,23 @@ struct DeathsPerInfectedCritical {
276277
}
277278
};
278279

280+
/**
281+
* @brief Distributions of the time that people have been in their initial infection state at the beginning of the simulation.
282+
* This makes it possible to draw from a user-defined distribution instead of drawing from a uniform distribution.
283+
*/
284+
struct InitialInfectionStateDistributions {
285+
using Type = CustomIndexArray<AbstractParameterDistribution, VirusVariant, AgeGroup, InfectionState>;
286+
static Type get_default(AgeGroup size)
287+
{
288+
return Type({VirusVariant::Count, size, InfectionState::Count},
289+
AbstractParameterDistribution(ParameterDistributionUniform(0., 1.)));
290+
}
291+
static std::string name()
292+
{
293+
return "InitialInfectionStateDistributions";
294+
}
295+
};
296+
279297
/**
280298
* @brief Parameters for the ViralLoad course. Default values taken as constant values from the average from
281299
* https://github.com/VirologyCharite/SARS-CoV-2-VL-paper/tree/main
@@ -317,29 +335,29 @@ struct ViralLoadDistributions {
317335
* @brief Parameters for the viral shed. Default values taken as constant values that match the graph 2C from
318336
* https://github.com/VirologyCharite/SARS-CoV-2-VL-paper/tree/main
319337
*/
320-
struct ViralShedParameters {
338+
struct ViralShedTuple {
321339
ScalarType viral_shed_alpha;
322340
ScalarType viral_shed_beta;
323341

324342
/// This method is used by the default serialization feature.
325343
auto default_serialize()
326344
{
327-
return Members("ViralShedParameters")
345+
return Members("ViralShedTuple")
328346
.add("viral_shed_alpha", viral_shed_alpha)
329347
.add("viral_shed_beta", viral_shed_beta);
330348
}
331349
};
332350

333-
struct ViralShedDistribution {
334-
using Type = CustomIndexArray<ViralShedParameters, VirusVariant, AgeGroup>;
351+
struct ViralShedParameters {
352+
using Type = CustomIndexArray<ViralShedTuple, VirusVariant, AgeGroup>;
335353
static Type get_default(AgeGroup size)
336354
{
337-
Type default_val({VirusVariant::Count, size}, ViralShedParameters{-7., 1.});
355+
Type default_val({VirusVariant::Count, size}, ViralShedTuple{-7., 1.});
338356
return default_val;
339357
}
340358
static std::string name()
341359
{
342-
return "ViralShedDistribution";
360+
return "ViralShedParameters";
343361
}
344362
};
345363

@@ -375,21 +393,6 @@ struct InfectionRateFromViralShed {
375393
}
376394
};
377395

378-
/**
379-
* @brief Probability that an Infection is detected.
380-
*/
381-
struct DetectInfection {
382-
using Type = CustomIndexArray<UncertainValue<ScalarType>, VirusVariant, AgeGroup>;
383-
static Type get_default(AgeGroup size)
384-
{
385-
return Type({VirusVariant::Count, size}, 1.);
386-
}
387-
static std::string name()
388-
{
389-
return "DetectInfection";
390-
}
391-
};
392-
393396
/**
394397
* @brief Effectiveness of a Mask of a certain MaskType% against an Infection%.
395398
*/
@@ -715,12 +718,13 @@ using ParametersBase =
715718
TimeInfectedSymptomsToSevere, TimeInfectedSymptomsToRecovered, TimeInfectedSevereToCritical,
716719
TimeInfectedSevereToRecovered, TimeInfectedSevereToDead, TimeInfectedCriticalToDead,
717720
TimeInfectedCriticalToRecovered, SymptomsPerInfectedNoSymptoms, SeverePerInfectedSymptoms,
718-
CriticalPerInfectedSevere, DeathsPerInfectedSevere, DeathsPerInfectedCritical, ViralLoadDistributions,
719-
ViralShedDistribution, ViralShedFactor, InfectionRateFromViralShed, DetectInfection, MaskProtection, AerosolTransmissionRates,
720-
LockdownDate, QuarantineDuration, QuarantineEffectiveness, SocialEventRate, BasicShoppingRate,
721-
WorkRatio, SchoolRatio, GotoWorkTimeMinimum, GotoWorkTimeMaximum, GotoSchoolTimeMinimum,
722-
GotoSchoolTimeMaximum, AgeGroupGotoSchool, AgeGroupGotoWork, InfectionProtectionFactor,
723-
SeverityProtectionFactor, HighViralLoadProtectionFactor, TestData>;
721+
CriticalPerInfectedSevere, DeathsPerInfectedSevere, DeathsPerInfectedCritical,
722+
InitialInfectionStateDistributions, ViralLoadDistributions, ViralShedParameters, ViralShedFactor,
723+
InfectionRateFromViralShed, MaskProtection, AerosolTransmissionRates, LockdownDate, QuarantineDuration,
724+
QuarantineEffectiveness, SocialEventRate, BasicShoppingRate, WorkRatio, SchoolRatio,
725+
GotoWorkTimeMinimum, GotoWorkTimeMaximum, GotoSchoolTimeMinimum, GotoSchoolTimeMaximum,
726+
AgeGroupGotoSchool, AgeGroupGotoWork, InfectionProtectionFactor, SeverityProtectionFactor,
727+
HighViralLoadProtectionFactor, TestData>;
724728

725729
/**
726730
* @brief Maximum number of Person%s an infectious Person can infect at the respective Location.
@@ -954,14 +958,6 @@ class Parameters : public ParametersBase
954958
(uint32_t)v, (size_t)i, 0);
955959
return true;
956960
}
957-
958-
if (this->get<DetectInfection>()[{v, i}] < 0.0 || this->get<DetectInfection>()[{v, i}] > 1.0) {
959-
log_error("Constraint check: Parameter DetectInfection of virus variant {} and age group {:.0f} "
960-
"smaller than {:d} or "
961-
"larger than {:d}",
962-
(uint32_t)v, (size_t)i, 0, 1);
963-
return true;
964-
}
965961
}
966962

967963
if (this->get<GotoWorkTimeMinimum>()[i].seconds() < 0.0 ||

cpp/models/abm/person.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@ ProtectionEvent Person::get_latest_protection(TimePoint t) const
233233
// Use reverse iterators to start from the most recent infection
234234
for (auto it = m_infections.rbegin(); it != m_infections.rend(); ++it) {
235235
if (it->get_start_date() <= t) {
236-
latest_exposure_type = ExposureType::NaturalInfection;
237-
latest_time = it->get_start_date();
236+
latest_protection_type = ProtectionType::NaturalInfection;
237+
latest_time = it->get_start_date();
238238
break; // Stop once we find the latest infection before time t
239239
}
240240
}
@@ -243,8 +243,8 @@ ProtectionEvent Person::get_latest_protection(TimePoint t) const
243243
for (auto it = m_vaccinations.rbegin(); it != m_vaccinations.rend(); ++it) {
244244
if (it->time <= t) {
245245
if (it->time > latest_time) {
246-
latest_exposure_type = it->exposure_type;
247-
latest_time = it->time;
246+
latest_protection_type = it->type;
247+
latest_time = it->time;
248248
}
249249
}
250250
else {

cpp/tests/test_abm_location.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ TEST_F(TestLocation, interact)
9292
params.get<mio::abm::ViralLoadDistributions>()[{variant, age}] = {mio::ParameterDistributionConstant(1.),
9393
mio::ParameterDistributionConstant(0.0001),
9494
mio::ParameterDistributionConstant(-0.0001)};
95-
params.set_default<mio::abm::ViralShedDistribution>(num_age_groups);
96-
params.get<mio::abm::ViralShedDistribution>()[{variant, age}] = {mio::ParameterDistributionConstant(1.),
97-
mio::ParameterDistributionConstant(1.)};
95+
params.set_default<mio::abm::ViralShedParameters>(num_age_groups);
96+
params.get<mio::abm::ViralShedParameters>()[{variant, age}] = {mio::ParameterDistributionConstant(1.),
97+
mio::ParameterDistributionConstant(1.)};
9898

9999
// Set incubtion period to two days so that the newly infected person is still exposed
100100
ScopedMockDistribution<testing::StrictMock<MockDistribution<mio::LogNormalDistribution<double>>>> mock_logNorm_dist;

cpp/tests/test_abm_model.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,6 @@ TEST_F(TestModel, checkParameterConstraints)
674674
params.get<mio::abm::CriticalPerInfectedSevere>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.05;
675675
params.get<mio::abm::DeathsPerInfectedSevere>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.001;
676676
params.get<mio::abm::DeathsPerInfectedCritical>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.1;
677-
params.get<mio::abm::DetectInfection>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3;
678677
params.get<mio::abm::GotoWorkTimeMinimum>()[age_group_35_to_59] = mio::abm::hours(4);
679678
params.get<mio::abm::GotoWorkTimeMaximum>()[age_group_35_to_59] = mio::abm::hours(8);
680679
params.get<mio::abm::GotoSchoolTimeMinimum>()[age_group_0_to_4] = mio::abm::hours(3);
@@ -737,10 +736,6 @@ TEST_F(TestModel, checkParameterConstraints)
737736
ASSERT_EQ(params.check_constraints(), true);
738737
params.get<mio::abm::TimeInfectedCriticalToRecovered>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] =
739738
mio::ParameterDistributionLogNormal(9., 0.5);
740-
params.get<mio::abm::DetectInfection>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 1.1;
741-
ASSERT_EQ(params.check_constraints(), true);
742-
params.get<mio::abm::DetectInfection>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.3;
743-
744739
params.get<mio::abm::SymptomsPerInfectedNoSymptoms>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = -0.1;
745740
ASSERT_EQ(params.check_constraints(), true);
746741
params.get<mio::abm::SymptomsPerInfectedNoSymptoms>()[{mio::abm::VirusVariant::Wildtype, age_group_0_to_4}] = 0.2;

0 commit comments

Comments
 (0)