Skip to content

Commit 5441ec1

Browse files
authored
Add support for Machine Translation model (microsoft#1482)
This pull request introduces the Marian example for machine translation, adds support for encoder-specific configurations, and enhances the handling of decoder inputs. Key changes include updates to the build system, new Marian examples in both C++ and Python, and modifications to the configuration and input handling for encoder-decoder models.

### Marian Example Addition:

* [`examples/c/CMakeLists.txt`](diffhunk://#diff-53132a16068a4beab79c0d0b3165704244a0ea790e4975b42c0e8573427a65baR13): Added the Marian example to the build options and defined its executable. [[1]](diffhunk://#diff-53132a16068a4beab79c0d0b3165704244a0ea790e4975b42c0e8573427a65baR13) [[2]](diffhunk://#diff-53132a16068a4beab79c0d0b3165704244a0ea790e4975b42c0e8573427a65baR88-R92)
* [`examples/c/src/marian.cpp`](diffhunk://#diff-bb987f6b4967383e87bf6e42bbe081e3269c0ecce61314deff2e6a06d5be59beR1-R115): Introduced a new C++ implementation for the Marian example, showcasing token generation using ONNX Runtime.
* [`examples/python/marian.py`](diffhunk://#diff-bdd0d4e08a0fe5b1b45141a29a780776e2cb352a9653099184892853be384ccfR1-R93): Added a Python implementation of the Marian example, supporting interactive and non-interactive modes for token generation.

### Configuration Enhancements:

* `src/config.h` and `src/config.cpp`: Replaced `EncoderDecoderInit` with a more detailed `Encoder` structure, adding fields for encoder-specific configurations such as `hidden_size`, `num_key_value_heads`, `num_hidden_layers`, `head_size`, and input/output names. Updated JSON parsing to handle these new fields.
[[1]](diffhunk://#diff-c24f78b3519d763901eb9f67b864f01d802d803df1b24faaf154019cf812bf95R46-R53) [[2]](diffhunk://#diff-c24f78b3519d763901eb9f67b864f01d802d803df1b24faaf154019cf812bf95L93-R121) [[3]](diffhunk://#diff-c24f78b3519d763901eb9f67b864f01d802d803df1b24faaf154019cf812bf95R199-R203) [[4]](diffhunk://#diff-c24f78b3519d763901eb9f67b864f01d802d803df1b24faaf154019cf812bf95R212) [[5]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385L166-R194) [[6]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385R225-R232) [[7]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385R257-R258) [[8]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385R379-R418) [[9]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385L616-R684) [[10]](diffhunk://#diff-6b2f0a449fdefd8930e23ef0dcd752beec69242e1303d77653f047c5e0766385L636-R703)

### Decoder Input Handling:

* `src/models/input_ids.cpp` and `src/models/input_ids.h`: Introduced `DecoderInputIDs` and `DecoderInputs` to handle decoder-specific input IDs, including methods for adding and updating decoder inputs. [[1]](diffhunk://#diff-68f9fc6a35c9a61872d9a22f44c0c311a8670fa6e26eb081077924d088d72decR36-R45) [[2]](diffhunk://#diff-68f9fc6a35c9a61872d9a22f44c0c311a8670fa6e26eb081077924d088d72decR60-R69) [[3]](diffhunk://#diff-68f9fc6a35c9a61872d9a22f44c0c311a8670fa6e26eb081077924d088d72decR126-R146) [[4]](diffhunk://#diff-9a2a38a53d0254d94c54591e1fd05c321b3ce5c3d38bb23123427dc3ffdf96c8R12-R18)

### Model updates:

* Added Marian model changes, which include encoder- and decoder-specific configurations.
* Updated the logits logic to reuse most of the code for RNN and transformer-based models (tested with both encoder-decoder and decoder-only models).
1 parent b8904ac commit 5441ec1

14 files changed

+412
-27
lines changed

examples/python/model-qa.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,10 @@ def main(args):
150150
else:
151151
messages = f"""[{{"role": "system", "content": "{system_prompt}"}}, {{"role": "user", "content": "{text}"}}]"""
152152
# Apply Chat Template
153-
prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True)
153+
if model.type == "marian-ssru":
154+
prompt = text
155+
else:
156+
prompt = tokenizer.apply_chat_template(messages=messages, add_generation_prompt=True)
154157
input_tokens = tokenizer.encode(prompt)
155158
generator.append_tokens(input_tokens)
156159

src/config.cpp

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,34 @@ struct SessionOptions_Element : JSON::Element {
163163
NamedStrings_Element config_entries_{v_.config_entries};
164164
};
165165

166-
struct EncoderDecoderInit_Element : JSON::Element {
167-
explicit EncoderDecoderInit_Element(Config::Model::EncoderDecoderInit& v) : v_{v} {}
166+
struct Encoder_Inputs_Element : JSON::Element {
167+
explicit Encoder_Inputs_Element(Config::Model::Encoder::Inputs& v) : v_{v} {}
168168

169169
void OnValue(std::string_view name, JSON::Value value) override {
170-
if (name == "filename") {
171-
v_.filename = JSON::Get<std::string_view>(value);
170+
if (name == "input_ids") {
171+
v_.input_ids = JSON::Get<std::string_view>(value);
172+
} else if (name == "attention_mask") {
173+
v_.attention_mask = JSON::Get<std::string_view>(value);
174+
} else
175+
throw JSON::unknown_value_error{};
176+
}
177+
178+
private:
179+
Config::Model::Encoder::Inputs& v_;
180+
};
181+
182+
struct Encoder_Outputs_Element : JSON::Element {
183+
explicit Encoder_Outputs_Element(Config::Model::Encoder::Outputs& v) : v_{v} {}
184+
185+
void OnValue(std::string_view name, JSON::Value value) override {
186+
if (name == "encoder_outputs") {
187+
v_.encoder_outputs = JSON::Get<std::string_view>(value);
172188
} else
173189
throw JSON::unknown_value_error{};
174190
}
175191

176192
private:
177-
Config::Model::EncoderDecoderInit& v_;
193+
Config::Model::Encoder::Outputs& v_;
178194
};
179195

180196
struct Inputs_Element : JSON::Element {
@@ -205,6 +221,14 @@ struct Inputs_Element : JSON::Element {
205221
v_.past_sequence_length = JSON::Get<std::string_view>(value);
206222
} else if (name == "total_sequence_length") {
207223
v_.total_sequence_length = JSON::Get<std::string_view>(value);
224+
} else if (name == "encoder_hidden_states") {
225+
v_.encoder_hidden_states = JSON::Get<std::string_view>(value);
226+
} else if (name == "encoder_attention_mask") {
227+
v_.encoder_attention_mask = JSON::Get<std::string_view>(value);
228+
} else if (name == "rnn_states_prev") {
229+
v_.rnn_prev_states = JSON::Get<std::string_view>(value);
230+
} else if (name == "past_key_values_length") {
231+
v_.past_key_values_length = JSON::Get<std::string_view>(value);
208232
} else
209233
throw JSON::unknown_value_error{};
210234
}
@@ -229,6 +253,8 @@ struct Outputs_Element : JSON::Element {
229253
v_.cross_present_key_names = JSON::Get<std::string_view>(value);
230254
} else if (name == "cross_present_value_names") {
231255
v_.cross_present_value_names = JSON::Get<std::string_view>(value);
256+
} else if (name == "rnn_states") {
257+
v_.rnn_states = JSON::Get<std::string_view>(value);
232258
} else
233259
throw JSON::unknown_value_error{};
234260
}
@@ -349,6 +375,40 @@ struct SlidingWindow_Element : JSON::Element {
349375
std::optional<Config::Model::Decoder::SlidingWindow>& v_;
350376
};
351377

378+
struct Encoder_Element : JSON::Element {
379+
explicit Encoder_Element(Config::Model::Encoder& v) : v_{v} {}
380+
381+
void OnValue(std::string_view name, JSON::Value value) override {
382+
if (name == "filename") {
383+
v_.filename = JSON::Get<std::string_view>(value);
384+
} else if (name == "hidden_size") {
385+
v_.hidden_size = static_cast<int>(JSON::Get<double>(value));
386+
} else if (name == "num_key_value_heads") {
387+
v_.num_key_value_heads = static_cast<int>(JSON::Get<double>(value));
388+
} else if (name == "num_hidden_layers") {
389+
v_.num_hidden_layers = static_cast<int>(JSON::Get<double>(value));
390+
} else if (name == "head_size") {
391+
v_.head_size = static_cast<int>(JSON::Get<double>(value));
392+
} else
393+
throw JSON::unknown_value_error{};
394+
}
395+
396+
Element& OnObject(std::string_view name) override {
397+
if (name == "inputs") {
398+
return inputs_;
399+
}
400+
if (name == "outputs") {
401+
return outputs_;
402+
}
403+
throw JSON::unknown_value_error{};
404+
}
405+
406+
private:
407+
Config::Model::Encoder& v_;
408+
Encoder_Inputs_Element inputs_{v_.inputs};
409+
Encoder_Outputs_Element outputs_{v_.outputs};
410+
};
411+
352412
struct Decoder_Element : JSON::Element {
353413
explicit Decoder_Element(Config::Model::Decoder& v) : v_{v} {}
354414

@@ -613,8 +673,8 @@ struct Model_Element : JSON::Element {
613673
}
614674

615675
Element& OnObject(std::string_view name) override {
616-
if (name == "encoder_decoder_init") {
617-
return encoder_decoder_init_;
676+
if (name == "encoder") {
677+
return encoder_;
618678
}
619679
if (name == "decoder") {
620680
return decoder_;
@@ -633,7 +693,7 @@ struct Model_Element : JSON::Element {
633693

634694
private:
635695
Config::Model& v_;
636-
EncoderDecoderInit_Element encoder_decoder_init_{v_.encoder_decoder_init};
696+
Encoder_Element encoder_{v_.encoder};
637697
Decoder_Element decoder_{v_.decoder};
638698
Int_Array_Element eos_token_id_{v_.eos_token_id};
639699
Vision_Element vision_{v_.vision};

src/config.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ struct Config {
2121
static constexpr std::string_view LogitsName = "logits";
2222
static constexpr std::string_view PresentKeyName = "present.%d.key";
2323
static constexpr std::string_view PresentValueName = "present.%d.value";
24+
static constexpr std::string_view RnnStatesName = "rnn_states";
25+
static constexpr std::string_view RnnStatesPrevName = "rnn_states_prev";
26+
static constexpr std::string_view PastKeyValuesLengthName = "past_key_values_length";
27+
static constexpr std::string_view EncoderHiddenStatesName = "encoder_hidden_states";
2428

2529
static constexpr std::string_view InputsEmbedsName = "inputs_embeds";
2630
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
@@ -43,6 +47,10 @@ struct Config {
4347
static constexpr std::string_view AudioProjectionModeName = "audio_projection_mode";
4448
static constexpr std::string_view AudioFeaturesName = "audio_features";
4549
static constexpr std::string_view NumAudioTokens = "num_audio_tokens";
50+
51+
// Encoder names
52+
static constexpr std::string_view EncoderOutputsName = "encoder_outputs";
53+
static constexpr std::string_view EncoderAttentionMaskName = "encoder_attention_mask";
4654
};
4755

4856
fs::path config_path; // Path of the config directory
@@ -90,13 +98,25 @@ struct Config {
9098
int context_length{};
9199

92100
// For models like whisper
93-
struct EncoderDecoderInit {
101+
struct Encoder {
94102
std::string filename;
95103

104+
int hidden_size{};
105+
int num_key_value_heads{};
106+
int num_hidden_layers{};
107+
int head_size{};
108+
96109
struct Inputs {
97110
std::string input_features{Defaults::InputFeaturesName};
111+
std::string input_ids{Defaults::InputIdsName};
112+
std::string attention_mask{Defaults::AttentionMaskName};
98113
} inputs;
99-
} encoder_decoder_init;
114+
115+
struct Outputs {
116+
std::string encoder_outputs{Defaults::EncoderOutputsName};
117+
} outputs;
118+
119+
} encoder;
100120

101121
struct Embedding {
102122
std::string filename;
@@ -174,7 +194,11 @@ struct Config {
174194
std::string cross_past_key_names, cross_past_value_names;
175195
std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
176196
std::string past_sequence_length{Defaults::PastSequenceLengthName};
197+
std::string past_key_values_length{Defaults::PastKeyValuesLengthName};
177198
std::string total_sequence_length{Defaults::TotalSequenceLengthName};
199+
std::string encoder_hidden_states{Defaults::EncoderHiddenStatesName};
200+
std::string rnn_prev_states{Defaults::RnnStatesPrevName};
201+
std::string encoder_attention_mask{Defaults::EncoderAttentionMaskName};
178202
} inputs;
179203

180204
struct Outputs {
@@ -183,6 +207,7 @@ struct Config {
183207
std::string present_value_names{Defaults::PresentValueName};
184208
std::string present_names; // When key/value pairs are combined
185209
std::string cross_present_key_names, cross_present_value_names;
210+
std::string rnn_states{Defaults::RnnStatesName};
186211
} outputs;
187212

188213
struct PipelineModel {

src/models/decoder_only.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, DeviceSpan<
1616
: State{params, model},
1717
model_{model},
1818
kv_cache_(CreateKeyValueCache(*this)),
19-
position_inputs_{model, *this, sequence_lengths_unk} {
19+
position_inputs_{model, *this, sequence_lengths_unk, model_.config_->model.decoder.inputs.attention_mask} {
2020
input_ids_.Add();
2121
position_inputs_.Add();
2222
logits_.Add();

src/models/decoder_only_pipeline.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ DecoderOnlyPipelineState::DecoderOnlyPipelineState(const DecoderOnlyPipelineMode
109109
key_value_cache_{CreateKeyValueCache(*this)},
110110
do_key_value_cache_partial_token_generation_update_{
111111
key_value_cache_ && key_value_cache_->IsPartialTokenGenerationUpdateSupported()},
112-
position_inputs_{CreatePositionInputs(*this, sequence_lengths)} {
112+
position_inputs_{CreatePositionInputs(*this, sequence_lengths, model_.config_->model.decoder.inputs.attention_mask)} {
113113
input_ids_->Add();
114114
position_inputs_->Add();
115115
logits_.Add();

src/models/gpt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ std::unique_ptr<State> Gpt_Model::CreateState(DeviceSpan<int32_t> sequence_lengt
1616
Gpt_State::Gpt_State(const Gpt_Model& model, DeviceSpan<int32_t> sequence_lengths_unk, const GeneratorParams& params)
1717
: State{params, model},
1818
model_{model},
19-
position_inputs_{model, *this, sequence_lengths_unk} {
19+
position_inputs_{model, *this, sequence_lengths_unk, model_.config_->model.decoder.inputs.attention_mask} {
2020
input_ids_.Add();
2121
position_inputs_.Add();
2222
logits_.Add();

0 commit comments

Comments
 (0)