Skip to content

Commit 6373848

Browse files
committed
Create ctranslate2 Moonshine implementation.
Adds the following: - c++ moonshine model - pybind for python moonshine model - moonshine model spec - safetensor moonshine model converter - support for GroupNorm-style weights for LayerNorm - support for multi-axis cuda layernorm
1 parent 383d063 commit 6373848

File tree

18 files changed

+871
-5
lines changed

18 files changed

+871
-5
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ set(SOURCES
129129
src/layers/wav2vec2.cc
130130
src/layers/wav2vec2bert.cc
131131
src/layers/whisper.cc
132+
src/layers/moonshine.cc
132133
src/logging.cc
133134
src/models/language_model.cc
134135
src/models/model.cc
@@ -139,6 +140,7 @@ set(SOURCES
139140
src/models/wav2vec2.cc
140141
src/models/wav2vec2bert.cc
141142
src/models/whisper.cc
143+
src/models/moonshine.cc
142144
src/ops/activation.cc
143145
src/ops/add.cc
144146
src/ops/alibi_add.cc
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include "ctranslate2/layers/transformer.h"
2+
3+
namespace ctranslate2 {
4+
namespace layers {
5+
6+
class MoonshinePreprocessor : public Layer {
7+
public:
8+
MoonshinePreprocessor(const models::Model& model, const std::string& scope);
9+
10+
void operator()(const StorageView& features, StorageView& output);
11+
12+
DataType output_type() const override {
13+
return _conv3.output_type();
14+
}
15+
16+
dim_t output_size() const override {
17+
return _conv3.output_size();
18+
}
19+
20+
dim_t input_size() const {
21+
return _conv1.input_size();
22+
}
23+
private:
24+
const Conv1D _conv1;
25+
const ops::Tanh _tanh;
26+
const LayerNorm _norm;
27+
const Conv1D _conv2;
28+
const ops::GELU _gelu1;
29+
const Conv1D _conv3;
30+
const ops::GELU _gelu2;
31+
const ops::Transpose _transpose;
32+
};
33+
34+
35+
class MoonshineEncoder : public Layer {
36+
public:
37+
MoonshineEncoder(const models::Model& model, const std::string& scope);
38+
39+
void operator()(const StorageView& features, StorageView& output);
40+
41+
DataType output_type() const override {
42+
return _output_norm.output_type();
43+
}
44+
45+
dim_t output_size() const override {
46+
return _output_norm.output_size();
47+
}
48+
49+
bool is_encoded(const StorageView& features) const {
50+
// Input features shape: [batch_size, input_size, input_time]
51+
// Encoder output shape: [batch_size, input_time // 2, output_size]
52+
//
53+
// input_time is variable so we check that dimension 1 is different than its original value.
54+
55+
return (features.rank() == 3
56+
&& features.dim(2) == output_size()
57+
&& features.dim(1) != 1);
58+
}
59+
60+
private:
61+
const dim_t _num_heads;
62+
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
63+
const LayerNorm _output_norm;
64+
};
65+
66+
class MoonshineDecoder : public TransformerDecoder {
67+
public:
68+
using TransformerDecoder::TransformerDecoder;
69+
70+
bool return_normalized_attention() const override {
71+
return false;
72+
}
73+
};
74+
}
75+
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#pragma once
2+
3+
#include "ctranslate2/generation.h"
4+
#include "ctranslate2/layers/moonshine.h"
5+
#include "ctranslate2/models/model.h"
6+
#include "ctranslate2/replica_pool.h"
7+
8+
namespace ctranslate2 {
9+
namespace models {
10+
11+
struct MoonshineOptions {
12+
// Beam size to use for beam search (set 1 to run greedy search).
13+
size_t beam_size = 5;
14+
15+
// Beam search patience factor, as described in https://arxiv.org/abs/2204.05424.
16+
// The decoding will continue until beam_size*patience hypotheses are finished.
17+
float patience = 1;
18+
19+
// Exponential penalty applied to the length during beam search.
20+
float length_penalty = 1;
21+
22+
// Penalty applied to the score of previously generated tokens, as described in
23+
// https://arxiv.org/abs/1909.05858 (set > 1 to penalize).
24+
float repetition_penalty = 1;
25+
26+
// Prevent repetitions of ngrams with this size (set 0 to disable).
27+
size_t no_repeat_ngram_size = 0;
28+
29+
// Maximum generation length.
30+
size_t max_length = 448;
31+
32+
// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
33+
size_t sampling_topk = 1;
34+
35+
// High temperatures increase randomness.
36+
float sampling_temperature = 1;
37+
38+
// Number of hypotheses to include in the result.
39+
size_t num_hypotheses = 1;
40+
41+
// Include scores in the result.
42+
bool return_scores = false;
43+
44+
// Suppress blank outputs at the beginning of the sampling.
45+
bool suppress_blank = true;
46+
47+
// List of token IDs to suppress.
48+
// -1 will suppress a default set of symbols as defined in the model config.json file.
49+
std::vector<int> suppress_tokens = {-1};
50+
};
51+
52+
struct MoonshineGenerationResult {
53+
std::vector<std::vector<std::string>> sequences;
54+
std::vector<std::vector<size_t>> sequences_ids;
55+
std::vector<float> scores;
56+
57+
size_t num_sequences() const {
58+
return sequences.size();
59+
}
60+
61+
bool has_scores() const {
62+
return !scores.empty();
63+
}
64+
};
65+
66+
class MoonshineModel : public Model {
67+
public:
68+
const Vocabulary& get_vocabulary() const;
69+
70+
size_t current_spec_revision() const override;
71+
bool is_quantizable(const std::string& variable_name) const override;
72+
bool is_linear_weight(const std::string& variable_name) const override;
73+
std::unique_ptr<Model> clone() const override;
74+
75+
bool use_global_int16_scale() const override {
76+
return false;
77+
}
78+
79+
protected:
80+
void initialize(ModelReader& model_reader) override;
81+
82+
private:
83+
std::shared_ptr<const Vocabulary> _vocabulary;
84+
};
85+
86+
class MoonshineReplica : public ModelReplica {
87+
public:
88+
static std::unique_ptr<MoonshineReplica> create_from_model(const Model& model);
89+
90+
MoonshineReplica(const std::shared_ptr<const MoonshineModel>& model);
91+
92+
StorageView encode(StorageView features, const bool to_cpu);
93+
94+
std::vector<MoonshineGenerationResult>
95+
generate(StorageView features,
96+
const std::vector<std::vector<std::string>>& prompts,
97+
const MoonshineOptions& options);
98+
99+
std::vector<MoonshineGenerationResult>
100+
generate(StorageView features,
101+
const std::vector<std::vector<size_t>>& prompts,
102+
const MoonshineOptions& options);
103+
104+
private:
105+
const std::shared_ptr<const MoonshineModel> _model;
106+
const std::unique_ptr<layers::MoonshinePreprocessor> _preprocessor;
107+
const std::unique_ptr<layers::MoonshineEncoder> _encoder;
108+
const std::unique_ptr<layers::MoonshineDecoder> _decoder;
109+
110+
size_t _sot_id;
111+
size_t _eot_id;
112+
113+
StorageView maybe_encode(StorageView features);
114+
};
115+
116+
class Moonshine : public ReplicaPool<MoonshineReplica> {
117+
public:
118+
using ReplicaPool::ReplicaPool;
119+
120+
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
121+
122+
std::vector<std::future<MoonshineGenerationResult>>
123+
generate(const StorageView& features,
124+
std::vector<std::vector<std::string>> prompts,
125+
MoonshineOptions options = {});
126+
127+
std::vector<std::future<MoonshineGenerationResult>>
128+
generate(const StorageView& features,
129+
std::vector<std::vector<size_t>> prompts,
130+
MoonshineOptions options = {});
131+
};
132+
133+
}
134+
}

include/ctranslate2/ops/layer_norm.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace ctranslate2 {
77

88
class LayerNorm : public TernaryOp {
99
public:
10-
LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5);
10+
LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5, const bool multi_axis=false);
1111

1212
using TernaryOp::operator();
1313
void operator()(const StorageView& beta,
@@ -32,10 +32,12 @@ namespace ctranslate2 {
3232
const dim_t outer_size,
3333
const dim_t axis_size,
3434
const dim_t inner_size,
35+
const bool multi_axis,
3536
StorageView& output) const;
3637

3738
const dim_t _axis;
3839
const float _epsilon;
40+
const bool _multi_axis;
3941
};
4042

4143
}

python/cpp/module.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,5 @@ PYBIND11_MODULE(_ext, m)
8989
ctranslate2::python::register_wav2vec2(m);
9090
ctranslate2::python::register_wav2vec2bert(m);
9191
ctranslate2::python::register_mpi(m);
92+
ctranslate2::python::register_moonshine(m);
9293
}

python/cpp/module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace ctranslate2 {
2020
void register_wav2vec2(py::module& m);
2121
void register_wav2vec2bert(py::module& m);
2222
void register_mpi(py::module& m);
23+
void register_moonshine(py::module& m);
2324

2425
}
2526
}

0 commit comments

Comments
 (0)