forked from marian-nmt/marian-dev
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_base.h
More file actions
79 lines (64 loc) · 1.9 KB
/
model_base.h
File metadata and controls
79 lines (64 loc) · 1.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#pragma once
#include <string>
#include "marian.h"
#include "layers/loss.h"
#include "layers/generic.h"
namespace marian {
namespace models {

/// Modes a model can be used in, as listed by the enumerators below.
/// Enumerators keep their implicit values (raw = 0 ... embedding = 4); the
/// registration that follows this block maps the type through `int`, so the
/// order presumably matters for configs -- do not reorder (TODO confirm).
enum class usage {
  raw,
  training,
  scoring,
  translation,
  embedding
};

}  // namespace models
}  // namespace marian
// Register marian::models::usage with the YAML layer, using `int` as the
// underlying representation. NOTE(review): presumably options are read and
// written as the integer enumerator value -- confirm against the macro's definition.
YAML_REGISTER_TYPE(marian::models::usage, int)
// FastOpt counterpart of the YAML registration above, written out by hand
// (the original note calls it 'FASTOPT_REGISTER_TYPE'): a `usage` option is
// read through the `int` converter and then cast to the enum.
#if FASTOPT
namespace marian {
namespace fastopt_helpers {

template <>
struct As<marian::models::usage> {
  static marian::models::usage apply(const FastOpt& node) {
    const auto asInt = As<int>::apply(node);
    return static_cast<marian::models::usage>(asInt);
  }
};

}  // namespace fastopt_helpers
}  // namespace marian
#endif
namespace marian {
namespace models {
// model = input -> predictions
/**
 * Abstract model interface: given a graph and a data batch, build() produces
 * the model's output as Logits.
 *
 * NOTE: several methods declare default arguments. For virtual functions,
 * C++ takes default arguments from the static type of the call expression,
 * so overrides should repeat the same defaults to avoid surprises.
 */
class IModel {
public:
  // Instances are handled polymorphically. A virtual destructor makes
  // destroying a derived model through an IModel pointer well-defined
  // (ICriterionFunction already declares one; IModel previously did not).
  virtual ~IModel() = default;

  /**
   * Load model state for `graph` from `name`.
   * @param graph        expression graph the model is bound to.
   * @param name         presumably a model file path -- TODO confirm against implementations.
   * @param markReloaded flag interpreted by implementations (meaning not visible here); defaults to true.
   */
  virtual void load(Ptr<ExpressionGraph> graph,
                    const std::string& name,
                    bool markReloaded = true)
      = 0;

  /**
   * Save model state held in `graph` to `name`.
   * @param graph                expression graph holding the model.
   * @param name                 presumably a destination file path -- TODO confirm.
   * @param saveTranslatorConfig flag interpreted by implementations; defaults to false.
   */
  virtual void save(Ptr<ExpressionGraph> graph,
                    const std::string& name,
                    bool saveTranslatorConfig = false)
      = 0;

  /**
   * Build the model's forward computation for `batch` on `graph`.
   * @param graph      expression graph to build into.
   * @param batch      input batch.
   * @param clearGraph flag defaulting to true; presumably controls whether the
   *                   graph is cleared first -- TODO confirm in implementations.
   * @return the model output as Logits.
   */
  virtual Logits build(Ptr<ExpressionGraph> graph,
                       Ptr<data::Batch> batch,
                       bool clearGraph = true)
      = 0;

  /// Implementation-defined cleanup associated with `graph`.
  virtual void clear(Ptr<ExpressionGraph> graph) = 0;
};
/**
 * Criterion interface: (input, reference) -> loss.
 * Mirrors IModel's load/save/build/clear shape, but build() yields a
 * Ptr<RationalLoss> instead of Logits.
 * @TODO: Is there a better name?
 */
class ICriterionFunction {
public:
  /// Virtual so derived criteria are destroyed correctly through a base pointer.
  virtual ~ICriterionFunction() = default;

  /// Load state for `graph` from `name`; `markReloaded` is interpreted by implementations.
  virtual void load(Ptr<ExpressionGraph> graph,
                    const std::string& name,
                    bool markReloaded = true) = 0;

  /// Save state from `graph` to `name`; `saveTranslatorConfig` is interpreted by implementations.
  virtual void save(Ptr<ExpressionGraph> graph,
                    const std::string& name,
                    bool saveTranslatorConfig = false) = 0;

  /// Build the loss for `batch` on `graph`; `clearGraph` defaults to true.
  virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
                                  Ptr<data::Batch> batch,
                                  bool clearGraph = true) = 0;

  /// Implementation-defined cleanup associated with `graph`.
  virtual void clear(Ptr<ExpressionGraph> graph) = 0;
};
} // namespace models
} // namespace marian