-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathedata.h
More file actions
580 lines (508 loc) · 17.2 KB
/
edata.h
File metadata and controls
580 lines (508 loc) · 17.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
#ifndef AMICI_EDATA_H
#define AMICI_EDATA_H
#include "amici/defines.h"
#include "amici/misc.h"
#include "amici/simulation_parameters.h"
#include <string>
#include <vector>
namespace amici {
class Model;
class ReturnData;
/**
* @brief ExpData carries all information about experimental or
* condition-specific data.
*/
class ExpData : public SimulationParameters {
  public:
    /**
     * @brief Default constructor.
     */
    ExpData() = default;

    /**
     * @brief Copy constructor.
     */
    // needs to be declared to be wrapped by SWIG
    ExpData(ExpData const&) = default;

    /**
     * @brief Move constructor.
     *
     * Explicitly defaulted because the user-declared copy constructor and
     * destructor suppress the implicitly-declared move operations; without
     * it, moving an ExpData would fall back to deep-copying all data vectors.
     */
    ExpData(ExpData&&) = default;

    /**
     * @brief Copy assignment operator.
     *
     * Explicitly defaulted because declaring a move constructor would
     * otherwise cause the implicit copy assignment operator to be deleted.
     */
    ExpData& operator=(ExpData const&) = default;

    /**
     * @brief Move assignment operator.
     */
    ExpData& operator=(ExpData&&) = default;

    /**
     * @brief Constructor that only initializes dimensions.
     *
     * @param nytrue Number of observables
     * @param nztrue Number of event outputs
     * @param nmaxevent Maximal number of events to track
     */
    ExpData(int nytrue, int nztrue, int nmaxevent);

    /**
     * @brief Constructor that initializes timepoints from vectors.
     *
     * @param nytrue Number of observables
     * @param nztrue Number of event outputs
     * @param nmaxevent Maximal number of events to track
     * @param ts Timepoints (dimension: nt)
     */
    ExpData(int nytrue, int nztrue, int nmaxevent, std::vector<realtype> ts);

    /**
     * @brief Constructor that initializes timepoints and fixed parameters
     * from vectors.
     *
     * @param nytrue Number of observables
     * @param nztrue Number of event outputs
     * @param nmaxevent Maximal number of events to track
     * @param ts Timepoints (dimension: nt)
     * @param fixed_parameters Model variables excluded from sensitivity
     * analysis (dimension: nk)
     */
    ExpData(
        int nytrue, int nztrue, int nmaxevent, std::vector<realtype> ts,
        std::vector<realtype> fixed_parameters
    );

    /**
     * @brief Constructor that initializes timepoints and data from vectors.
     *
     * @param nytrue Number of observables
     * @param nztrue Number of event outputs
     * @param nmaxevent Maximal number of events to track
     * @param ts Timepoints (dimension: nt)
     * @param observed_data observed data (dimension: nt x nytrue, row-major)
     * @param observed_data_std_dev standard deviation of observed data
     * (dimension: nt x nytrue, row-major)
     * @param observed_events observed events
     * (dimension: nmaxevent x nztrue, row-major)
     * @param observed_events_std_dev standard deviation of observed
     * events/roots (dimension: nmaxevent x nztrue, row-major)
     */
    ExpData(
        int nytrue, int nztrue, int nmaxevent, std::vector<realtype> ts,
        std::vector<realtype> const& observed_data,
        std::vector<realtype> const& observed_data_std_dev,
        std::vector<realtype> const& observed_events,
        std::vector<realtype> const& observed_events_std_dev
    );

    /**
     * @brief Constructor that initializes with Model.
     *
     * @param model model specification object
     */
    explicit ExpData(Model const& model);

    /**
     * @brief Constructor that initializes with ReturnData, adds normally
     * distributed noise according to specified sigmas.
     *
     * @param rdata return data with stored simulation results
     * @param sigma_y scalar standard deviation for all observables
     * @param sigma_z scalar standard deviation for all event observables
     * @param seed Seed for the random number generator. If a negative number
     * is passed, a random seed is used.
     */
    ExpData(
        ReturnData const& rdata, realtype sigma_y, realtype sigma_z,
        int seed = -1
    );

    /**
     * @brief Constructor that initializes with ReturnData, adds normally
     * distributed noise according to specified sigmas.
     *
     * @param rdata return data with stored simulation results
     * @param sigma_y vector of standard deviations for observables
     * (dimension: nytrue or nt x nytrue, row-major)
     * @param sigma_z vector of standard deviations for event observables
     * (dimension: nztrue or nmaxevent x nztrue, row-major)
     * @param seed Seed for the random number generator. If a negative number
     * is passed, a random seed is used.
     */
    ExpData(
        ReturnData const& rdata, std::vector<realtype> sigma_y,
        std::vector<realtype> sigma_z, int seed = -1
    );

    ~ExpData() = default;

    friend inline bool operator==(ExpData const& lhs, ExpData const& rhs);

    /**
     * @brief number of observables of the non-augmented model
     *
     * @return number of observables of the non-augmented model
     */
    int nytrue() const;

    /**
     * @brief number of event observables of the non-augmented model
     *
     * @return number of event observables of the non-augmented model
     */
    int nztrue() const;

    /**
     * @brief maximal number of events to track
     *
     * @return maximal number of events to track
     */
    int nmaxevent() const;

    /**
     * @brief number of timepoints
     *
     * @return number of timepoints
     */
    int nt() const;

    /**
     * @brief Set output timepoints.
     *
     * If the number of timepoints increases, this will grow the
     * observation/sigma matrices and fill new entries with NaN.
     * If the number of timepoints decreases, this will shrink the
     * observation/sigma matrices.
     *
     * Note that the mapping from timepoints to measurements will not be
     * preserved. E.g., say there are measurements at t = 2, and this
     * function is called with [1, 2], then the old measurements will belong
     * to t = 1.
     *
     * @param ts new output timepoints
     */
    void set_timepoints(std::vector<realtype> const& ts);

    /**
     * @brief Get output timepoints.
     *
     * @return ExpData::ts
     */
    std::vector<realtype> const& get_timepoints() const;

    /**
     * @brief Get timepoint for the given index.
     *
     * @param it timepoint index
     *
     * @return timepoint at index `it`
     */
    realtype get_timepoint(int it) const;

    /**
     * @brief Set all measurements.
     *
     * @param observed_data observed data (dimension: nt x nytrue, row-major)
     */
    void set_observed_data(std::vector<realtype> const& observed_data);

    /**
     * @brief Set measurements for a given observable index.
     *
     * @param observed_data observed data (dimension: nt)
     * @param iy observable index
     */
    void set_observed_data(std::vector<realtype> const& observed_data, int iy);

    /**
     * @brief Whether there is a measurement for the given time- and
     * observable-index.
     *
     * @param it time index
     * @param iy observable index
     *
     * @return boolean specifying if data was set
     */
    bool is_set_observed_data(int it, int iy) const;

    /**
     * @brief Get all measurements.
     *
     * @return observed data (dimension: nt x nytrue, row-major)
     */
    std::vector<realtype> const& get_observed_data() const;

    /**
     * @brief Get measurements for a given timepoint index.
     *
     * @param it timepoint index
     *
     * @return pointer to observed data at index (dimension: nytrue)
     */
    realtype const* get_observed_data_ptr(int it) const;

    /**
     * @brief Set standard deviations for measurements.
     *
     * @param observed_data_std_dev standard deviation of observed data
     * (dimension: nt x nytrue, row-major)
     */
    void set_observed_data_std_dev(
        std::vector<realtype> const& observed_data_std_dev
    );

    /**
     * @brief Set identical standard deviation for all measurements.
     *
     * @param stdDev standard deviation (dimension: scalar)
     */
    void set_observed_data_std_dev(realtype stdDev);

    /**
     * @brief Set standard deviations of observed data for a
     * specific observable index.
     *
     * @param observedDataStdDev standard deviation of observed data
     * (dimension: nt)
     * @param iy observable index
     */
    void set_observed_data_std_dev(
        std::vector<realtype> const& observedDataStdDev, int iy
    );

    /**
     * @brief Set all standard deviations for a given observable index to the
     * input value.
     *
     * @param stdDev standard deviation (dimension: scalar)
     * @param iy observable index
     */
    void set_observed_data_std_dev(realtype stdDev, int iy);

    /**
     * @brief Whether standard deviation for a measurement at
     * specified timepoint- and observable index has been set.
     *
     * @param it time index
     * @param iy observable index
     * @return boolean specifying if standard deviation of data was set
     */
    bool is_set_observed_data_std_dev(int it, int iy) const;

    /**
     * @brief Get measurement standard deviations.
     *
     * @return standard deviation of observed data
     * (dimension: nt x nytrue, row-major)
     */
    std::vector<realtype> const& get_observed_data_std_dev() const;

    /**
     * @brief Get pointer to measurement standard deviations.
     *
     * @param it timepoint index
     * @return pointer to standard deviation of observed data at index
     */
    realtype const* get_observed_data_std_dev_ptr(int it) const;

    /**
     * @brief Set observed event data.
     *
     * @param observedEvents observed data (dimension: nmaxevent x nztrue,
     * row-major)
     */
    void set_observed_events(std::vector<realtype> const& observedEvents);

    /**
     * @brief Set observed event data for specific event observable.
     *
     * @param observedEvents observed data (dimension: nmaxevent)
     * @param iz event observable index
     */
    void
    set_observed_events(std::vector<realtype> const& observedEvents, int iz);

    /**
     * @brief Check whether event data at specified indices has been set.
     *
     * @param ie event index
     * @param iz event observable index
     * @return boolean specifying if data was set
     */
    bool is_set_observed_events(int ie, int iz) const;

    /**
     * @brief Get observed event data.
     *
     * @return observed event data (dimension: nmaxevent x nztrue, row-major)
     */
    std::vector<realtype> const& get_observed_events() const;

    /**
     * @brief Get pointer to observed data at ie-th occurrence.
     *
     * @param ie event occurrence
     *
     * @return pointer to observed event data at ie-th occurrence
     */
    realtype const* get_observed_events_ptr(int ie) const;

    /**
     * @brief Set standard deviation of observed event data.
     *
     * @param observedEventsStdDev standard deviation of observed event data
     * (dimension: nmaxevent x nztrue, row-major)
     */
    void set_observed_events_std_dev(
        std::vector<realtype> const& observedEventsStdDev
    );

    /**
     * @brief Set identical standard deviation for all observed event data.
     *
     * @param stdDev standard deviation (dimension: scalar)
     */
    void set_observed_events_std_dev(realtype stdDev);

    /**
     * @brief Set standard deviation of observed event data for a specific
     * event observable.
     *
     * @param observedEventsStdDev standard deviation of observed event data
     * (dimension: nmaxevent)
     * @param iz event observable index
     */
    void set_observed_events_std_dev(
        std::vector<realtype> const& observedEventsStdDev, int iz
    );

    /**
     * @brief Set all standard deviations of a specific event observable.
     *
     * @param stdDev standard deviation (dimension: scalar)
     * @param iz event observable index
     */
    void set_observed_events_std_dev(realtype stdDev, int iz);

    /**
     * @brief Check whether standard deviation of event data
     * at specified indices has been set.
     *
     * @param ie event index
     * @param iz event observable index
     * @return boolean specifying if standard deviation of event data was set
     */
    bool is_set_observed_events_std_dev(int ie, int iz) const;

    /**
     * @brief Get standard deviation of observed event data.
     *
     * @return standard deviation of observed event data
     */
    std::vector<realtype> const& get_observed_events_std_dev() const;

    /**
     * @brief Get pointer to standard deviation of
     * observed event data at ie-th occurrence.
     *
     * @param ie event occurrence
     *
     * @return pointer to standard deviation of observed event data at ie-th
     * occurrence
     */
    realtype const* get_observed_events_std_dev_ptr(int ie) const;

    /**
     * @brief Set all observations and their standard deviations to NaN.
     *
     * Useful, e.g., after calling ExpData::set_timepoints.
     */
    void clear_observations();

    /**
     * @brief Arbitrary (not necessarily unique) identifier.
     */
    std::string id;

  protected:
    /**
     * @brief Resizes observed_data_, observed_data_std_dev_,
     * observed_events_ and observed_events_std_dev_.
     */
    void apply_dimensions();

    /**
     * @brief Resizes observed_data_ and observed_data_std_dev_.
     */
    void apply_data_dimension();

    /**
     * @brief Resizes observed_events_ and observed_events_std_dev_.
     */
    void apply_event_dimension();

    /**
     * @brief Check the dimension of input observed data or its standard
     * deviation (expected dimension: nt x nytrue).
     *
     * @param input vector input to be checked
     * @param fieldname name of the input
     */
    void check_data_dimension(
        std::vector<realtype> const& input, char const* fieldname
    ) const;

    /**
     * @brief Check the dimension of input observed events or their standard
     * deviation (expected dimension: nmaxevent x nztrue).
     *
     * @param input vector input to be checked
     * @param fieldname name of the input
     */
    void check_events_dimension(
        std::vector<realtype> const& input, char const* fieldname
    ) const;

    /** @brief number of observables */
    int nytrue_{0};

    /** @brief number of event observables */
    int nztrue_{0};

    /** @brief maximal number of event occurrences */
    int nmaxevent_{0};

    /** @brief observed data (dimension: nt x nytrue, row-major) */
    std::vector<realtype> observed_data_;

    /**
     * @brief standard deviation of observed data (dimension: nt x nytrue,
     * row-major)
     */
    std::vector<realtype> observed_data_std_dev_;

    /**
     * @brief observed events (dimension: nmaxevent x nztrue, row-major)
     */
    std::vector<realtype> observed_events_;

    /**
     * @brief standard deviation of observed events/roots
     * (dimension: nmaxevent x nztrue, row-major)
     */
    std::vector<realtype> observed_events_std_dev_;
};
/**
 * @brief Equality operator
 *
 * Two ExpData objects compare equal if their SimulationParameters base
 * subobjects compare equal and all ExpData-specific members (identifier,
 * dimensions, observed data/events and their standard deviations) are equal.
 * The vector members are compared via is_equal.
 *
 * @param lhs some object
 * @param rhs another object
 * @return `true`, if both arguments are equal; `false` otherwise.
 */
inline bool operator==(ExpData const& lhs, ExpData const& rhs) {
    // Upcast to the (public) base class. This conversion is resolved at
    // compile time, so static_cast is sufficient; dynamic_cast adds nothing.
    auto const& lhs_base = static_cast<SimulationParameters const&>(lhs);
    auto const& rhs_base = static_cast<SimulationParameters const&>(rhs);

    return lhs_base == rhs_base && lhs.id == rhs.id
           && lhs.nytrue_ == rhs.nytrue_ && lhs.nztrue_ == rhs.nztrue_
           && lhs.nmaxevent_ == rhs.nmaxevent_
           && is_equal(lhs.observed_data_, rhs.observed_data_)
           && is_equal(lhs.observed_data_std_dev_, rhs.observed_data_std_dev_)
           && is_equal(lhs.observed_events_, rhs.observed_events_)
           && is_equal(
               lhs.observed_events_std_dev_, rhs.observed_events_std_dev_
           );
}
/**
 * @brief Check a vector of standard deviations (sigmas) for entries that
 * are not strictly positive.
 *
 * @param sigma_vector sigmas to be checked
 * @param vector_name name of the checked input (presumably used when
 * reporting a violation; confirm in the implementation)
 */
void check_sigma_positivity(
    std::vector<realtype> const& sigma_vector, char const* vector_name
);
/**
 * @brief Check a single standard deviation (sigma) for a value that is not
 * strictly positive.
 *
 * @param sigma sigma to be checked
 * @param sigma_name name of the checked input (presumably used when
 * reporting a violation; confirm in the implementation)
 */
void check_sigma_positivity(realtype sigma, char const* sigma_name);
/**
 * @brief The ConditionContext class applies condition-specific amici::Model
 * settings and restores them when going out of scope.
 *
 * Copying is disabled: restore() is invoked on destruction, so a copy
 * sharing the same model pointer and backups would restore the model a
 * second time.
 */
class ConditionContext : public ContextManager {
  public:
    /**
     * @brief Apply condition-specific settings from edata to model while
     * keeping a backup of the original values.
     *
     * @param model model to which the condition-specific settings are
     * applied and whose original settings are backed up
     * @param edata source of the condition-specific settings
     * (default: nullptr)
     * @param fpc flag indicating which fixedParameter from edata to apply
     */
    explicit ConditionContext(
        Model* model, ExpData const* edata = nullptr,
        FixedParameterContext fpc = FixedParameterContext::simulation
    );

    ConditionContext(ConditionContext const& other) = delete;

    ConditionContext& operator=(ConditionContext const& other) = delete;

    ~ConditionContext();

    /**
     * @brief Apply condition-specific settings from edata to the
     * constructor-supplied model, not changing the settings which were
     * backed-up in the constructor call.
     *
     * @param edata source of the condition-specific settings
     * @param fpc flag indicating which fixedParameter from edata to apply
     */
    void apply_condition(ExpData const* edata, FixedParameterContext fpc);

    /**
     * @brief Restore original settings on constructor-supplied amici::Model.
     * Will be called during destruction. Explicit call is generally not
     * necessary.
     */
    void restore();

  private:
    /** Model whose settings are modified and restored (presumably
     * non-owning; TODO confirm). */
    Model* model_ = nullptr;

    // Backups of the model settings taken at construction, written back by
    // restore().
    std::vector<realtype> original_x0_;
    std::vector<realtype> original_sx0_;
    std::vector<realtype> original_parameters_;
    std::vector<realtype> original_fixed_parameters_;
    realtype original_tstart_{};
    realtype original_tstart_preeq_{};
    std::vector<realtype> original_timepoints_;
    std::vector<int> original_parameter_list_;
    std::vector<ParameterScaling> original_scaling_;
    bool original_reinitialize_fixed_parameter_initial_states_{false};
    std::vector<int> original_reinitialization_state_idxs;
};
} // namespace amici
#endif /* AMICI_EDATA_H */