Skip to content

Commit 71b86a9

Browse files
feat(core): Support serializing AdamParams and SEKParams
1 parent 21b69e6 commit 71b86a9

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

core/include/gprat/hyperparameters.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "gprat/detail/config.hpp"
77

88
#include <string>
9+
#include <memory>
910

1011
GPRAT_NS_BEGIN
1112

@@ -58,6 +59,30 @@ struct AdamParams
5859
std::string repr() const;
5960
};
6061

62+
template <class Archive>
63+
void save_construct_data(Archive &ar, const AdamParams *v, const unsigned int)
64+
{
65+
ar << v->learning_rate;
66+
ar << v->beta1;
67+
ar << v->beta2;
68+
ar << v->epsilon;
69+
ar << v->opt_iter;
70+
}
71+
72+
template <class Archive>
73+
void load_construct_data(Archive &ar, AdamParams *v, const unsigned int)
74+
{
75+
double learning_rate, beta1, beta2, epsilon;
76+
int opt_iter;
77+
ar >> learning_rate;
78+
ar >> beta1;
79+
ar >> beta2;
80+
ar >> epsilon;
81+
ar >> opt_iter;
82+
83+
std::construct_at(v, learning_rate, beta1, beta2, epsilon, opt_iter);
84+
}
85+
6186
GPRAT_NS_END
6287

6388
#endif

core/include/gprat/kernels.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <cstddef>
99
#include <vector>
10+
#include <memory>
1011

1112
GPRAT_NS_BEGIN
1213

@@ -79,6 +80,31 @@ struct SEKParams
7980
const double &get_param(std::size_t index) const;
8081
};
8182

83+
template <class Archive>
84+
void save_construct_data(Archive &ar, const SEKParams *v, const unsigned int)
85+
{
86+
ar << v->lengthscale;
87+
ar << v->vertical_lengthscale;
88+
ar << v->noise_variance;
89+
}
90+
91+
template <class Archive>
92+
void load_construct_data(Archive &ar, SEKParams *v, const unsigned int)
93+
{
94+
double lengthscale, vertical_lengthscale, noise_variance;
95+
ar >> lengthscale;
96+
ar >> vertical_lengthscale;
97+
ar >> noise_variance;
98+
99+
std::construct_at(v, lengthscale, vertical_lengthscale, noise_variance);
100+
}
101+
102+
template <typename Archive>
103+
void serialize(Archive &ar, SEKParams &pt, const unsigned int)
104+
{
105+
ar & pt.m_T & pt.w_T;
106+
}
107+
82108
GPRAT_NS_END
83109

84110
#endif

0 commit comments

Comments
 (0)