-
Notifications
You must be signed in to change notification settings - Fork 436
Add a new model: WavLM #1966
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
a2d8a4v
wants to merge
19
commits into
OpenNMT:master
Choose a base branch
from
a2d8a4v:feat/wavlm
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add a new model: WavLM #1966
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
d476985
Create median_filter_gpu.cu
a2d8a4v d51d661
Create median_filter_cpu.cc
a2d8a4v f35fb69
Update median_filter.h to contain CPU and GPU compute call
a2d8a4v 9cd7f17
Add CPU and GPU of median_filter operator
a2d8a4v 08e7dc0
Update median_filter.cc
a2d8a4v d645853
Merge branch 'master' into master
jordimas 62baa10
Add performance benchmark test
jordimas c937f94
Run the median filter tests also in CUDA device
jordimas 41c658c
Merge branch 'OpenNMT:master' into master
a2d8a4v fe7a80e
Merge branch 'OpenNMT:master' into master
a2d8a4v bccfd43
Add a new model: WavLM
a2d8a4v d0efe55
fix format of coding
a2d8a4v 0426c16
Add description for encoder-only models
a2d8a4v c1ce83e
Merge branch 'master' into feat/wavlm
jordimas f2d99e2
fix a bug: missing gated_relative_attention_bias argument
a2d8a4v b147d4d
using black to reformat
a2d8a4v 4c75c2f
resorting import
a2d8a4v 753db44
Merge branch 'master' into feat/wavlm
jordimas 488d76c
Merge branch 'master' into feat/wavlm
jordimas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| #pragma once | ||
|
|
||
| #include <optional> | ||
| #include "ctranslate2/layers/transformer.h" | ||
|
|
||
| namespace ctranslate2 { | ||
| namespace layers { | ||
|
|
||
| class WavLMLayerNormConvLayer : public Layer { | ||
| public: | ||
| WavLMLayerNormConvLayer(const models::Model& model, | ||
| const std::string& scope, | ||
| dim_t stride, | ||
| dim_t padding); | ||
|
|
||
| void operator()(const StorageView& input, StorageView& output) const; | ||
|
|
||
| DataType output_type() const override { | ||
| return _conv.output_type(); | ||
| } | ||
|
|
||
| dim_t output_size() const override { | ||
| return _conv.output_size(); | ||
| } | ||
|
|
||
| private: | ||
| dim_t _stride; | ||
| dim_t _padding; | ||
| const Conv1D _conv; | ||
| const LayerNorm _output_norm; | ||
| const ops::Transpose _transpose; | ||
| const ops::GELU _gelu; | ||
| }; | ||
|
|
||
| class WavLMPosConvLayer : public Layer { | ||
| public: | ||
| WavLMPosConvLayer(const models::Model& model, const std::string& scope); | ||
|
|
||
| void operator()(const StorageView& input, StorageView& output) const; | ||
|
|
||
| DataType output_type() const override { | ||
| return _conv.output_type(); | ||
| } | ||
|
|
||
| dim_t output_size() const override { | ||
| return _conv.output_size(); | ||
| } | ||
|
|
||
| private: | ||
| const Conv1D _conv; | ||
| const ops::Transpose _transpose; | ||
| const ops::GELU _gelu; | ||
| }; | ||
|
|
||
| class WavLMEncoder : public Layer { | ||
| public: | ||
| WavLMEncoder(const models::Model& model, const std::string& scope); | ||
|
|
||
| void operator()(const StorageView& features, StorageView& output); | ||
|
|
||
| DataType output_type() const override { | ||
| if (_lm_head) { | ||
| return (*_lm_head).output_type(); | ||
| } | ||
| else { | ||
| return _output_norm.output_type(); | ||
| } | ||
| } | ||
|
|
||
| dim_t output_size() const override { | ||
| if (_lm_head) { | ||
| return (*_lm_head).output_size(); | ||
| } | ||
| else { | ||
| return _output_norm.output_size(); | ||
| } | ||
| } | ||
|
|
||
| dim_t input_size() const { | ||
| return 1024; | ||
| } | ||
|
|
||
| bool is_encoded(const StorageView& features) const { | ||
| // Input features shape: [batch_size, input_size, input_time] | ||
| // Encoder output shape: [batch_size, input_time // 2, output_size] | ||
| // | ||
| // input_time is variable so we check that dimension 1 is different than its original value. | ||
|
|
||
| return (features.rank() == 3 | ||
| && features.dim(2) == output_size() | ||
| && features.dim(1) != input_size()); | ||
| } | ||
|
|
||
| const StorageView* _upgraded_model; | ||
|
|
||
| private: | ||
| const StorageView* _return_logits; | ||
| std::optional<WavLMLayerNormConvLayer> _feat_layer0; | ||
| std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers; | ||
| std::optional<LayerNorm> _fp_norm; | ||
| std::optional<Dense> _fp_ff; | ||
| std::optional<WavLMPosConvLayer> _pos_conv_embed; | ||
| const ops::Transpose _transpose; | ||
| const ops::GELU _gelu; | ||
| const dim_t _num_heads; | ||
| const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers; | ||
| const LayerNorm _output_norm; | ||
| std::optional<Dense> _lm_head; | ||
| }; | ||
|
|
||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| #pragma once | ||
|
|
||
| //#include "ctranslate2/generation.h" | ||
| #include "ctranslate2/layers/wavlm.h" | ||
| #include "ctranslate2/models/model.h" | ||
| #include "ctranslate2/replica_pool.h" | ||
|
|
||
| namespace ctranslate2 { | ||
| namespace models { | ||
|
|
||
// Generation options for the WavLM model.
// NOTE(review): this structure is currently not referenced anywhere (it
// mirrors the Whisper options); either wire it into the encode/generation
// API or remove it.
struct WavLMOptions {
  // Maximum generation length.
  size_t max_length = 448;

  // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
  size_t sampling_topk = 1;

  // Maximum index of the first predicted timestamp.
  size_t max_initial_timestamp_index = 50;

  // Suppress blank outputs at the beginning of the sampling.
  bool suppress_blank = true;

  // List of token IDs to suppress.
  // -1 will suppress a default set of symbols as defined in the model config.json file.
  std::vector<int> suppress_tokens = {-1};
};
|
|
||
|
|
||
// Model wrapper holding the WavLM weights and vocabulary.
class WavLMModel : public Model {
public:
  // Returns the model vocabulary loaded during initialize().
  const Vocabulary& get_vocabulary() const;
  size_t current_spec_revision() const override;
  // Whether `variable_name` can be quantized / is a linear (dense) weight.
  bool is_quantizable(const std::string& variable_name) const override;
  bool is_linear_weight(const std::string& variable_name) const override;
  std::unique_ptr<Model> clone() const override;

  // Per-variable int16 scales are used instead of a single global scale.
  bool use_global_int16_scale() const override {
    return false;
  }

protected:
  // Loads the vocabulary and any model-specific resources from the reader.
  void initialize(ModelReader& model_reader) override;
private:
  std::shared_ptr<const Vocabulary> _vocabulary;
};
|
|
||
// One worker replica of the WavLM model, owning its encoder instance.
class WavLMReplica : public ModelReplica {
public:
  // Builds a replica from a loaded model; `model` must be a WavLMModel.
  static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);

  WavLMReplica(const std::shared_ptr<const WavLMModel>& model);
  // Encodes the input features; if `to_cpu` is true the result is copied
  // back to host memory before returning.
  StorageView encode(StorageView features, const bool to_cpu);
private:
  const std::shared_ptr<const WavLMModel> _model;
  const std::unique_ptr<layers::WavLMEncoder> _encoder;

  // Runs the encoder unless `features` is already an encoder output
  // (presumably checked via WavLMEncoder::is_encoded — confirm in the .cc).
  StorageView maybe_encode(StorageView features);
};
|
|
||
// Thread-pool front-end for WavLM: dispatches encode requests to the
// available replicas and returns the result asynchronously.
class WavLM : public ReplicaPool<WavLMReplica> {
public:
  using ReplicaPool::ReplicaPool;
  // Asynchronously encodes `features`; see WavLMReplica::encode for `to_cpu`.
  std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
};
|
|
||
| } | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| #include "module.h" | ||
|
|
||
| #include <ctranslate2/models/wavlm.h> | ||
|
|
||
| #include "replica_pool.h" | ||
|
|
||
| #include <iostream> | ||
|
|
||
| namespace ctranslate2 { | ||
| namespace python { | ||
|
|
||
// Python-facing wrapper around the WavLM replica pool.
class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Blocking encode: submits the features to the pool and waits for the
  // result. The shared lock guards against concurrent model unloading.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    std::shared_lock lock(_mutex);
    assert_model_is_ready();
    return _pool->encode(features, to_cpu).get();
  }
};
|
|
||
// Registers the `ctranslate2.WavLM` Python class and its methods/properties
// on the extension module `m`.
void register_wavlm(py::module& m) {
  py::class_<WavLMWrapper>(
    m, "WavLM",
    R"pbdoc(
        Implements the WavLM speech recognition model published by Microsoft.
    )pbdoc")

    // Constructor mirrors the other model wrappers (device/compute options).
    .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
         py::arg("model_path"),
         py::arg("device")="cpu",
         py::kw_only(),
         py::arg("device_index")=0,
         py::arg("compute_type")="default",
         py::arg("inter_threads")=1,
         py::arg("intra_threads")=0,
         py::arg("max_queued_batches")=0,
         py::arg("flash_attention")=false,
         py::arg("tensor_parallel")=false,
         py::arg("files")=py::none(),
         R"pbdoc(
             Initializes a WavLM model from a converted model.

             Arguments:
               model_path: Path to the CTranslate2 model directory.
               device: Device to use (possible values are: cpu, cuda, auto).
               device_index: Device IDs where to place this model on.
               compute_type: Model computation type or a dictionary mapping a device name
                 to the computation type (possible values are: default, auto, int8, int8_float32,
                 int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
               inter_threads: Number of workers to allow executing multiple batches in parallel.
               intra_threads: Number of OpenMP threads per worker (0 to use a default value).
               max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                 0 for an automatic value). When the queue is full, future requests will block
                 until a free slot is available.
               flash_attention: run model with flash attention 2 for self-attention layer
               tensor_parallel: run model with tensor parallel mode
               files: Load model files from the memory. This argument is a dictionary mapping
                 file names to file contents as file-like or bytes objects. If this is set,
                 :obj:`model_path` acts as an identifier for this model.
         )pbdoc")

    // Read-only introspection properties shared with the other wrappers.
    .def_property_readonly("device", &WavLMWrapper::device,
                           "Device this model is running on.")
    .def_property_readonly("device_index", &WavLMWrapper::device_index,
                           "List of device IDs where this model is running on.")
    .def_property_readonly("compute_type", &WavLMWrapper::compute_type,
                           "Computation type used by the model.")
    .def_property_readonly("num_workers", &WavLMWrapper::num_replicas,
                           "Number of model workers backing this instance.")
    .def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches,
                           "Number of batches waiting to be processed.")
    .def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel,
                           "Run model with tensor parallel mode.")
    .def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches,
                           "Number of batches waiting to be processed or currently processed.")

    // The GIL is released while encoding so other Python threads can run.
    .def("encode", &WavLMWrapper::encode,
         py::arg("features"),
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Encodes the input features.

             Arguments:
               features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
                 raw audio, as a float array with shape (followed by VAD)
                 ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
               to_cpu: Copy the encoder output to the CPU before returning the value.

             Returns:
               The encoder output.
         )pbdoc")

    .def("unload_model", &WavLMWrapper::unload_model,
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Unloads the model attached to this wavlm but keep enough runtime context
             to quickly resume wavlm on the initial device.

             Arguments:
               to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
         )pbdoc")

    .def("load_model", &WavLMWrapper::load_model,
         py::arg("keep_cache")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Loads the model back to the initial device.

             Arguments:
               keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
         )pbdoc")

    .def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded,
                           "Whether the model is loaded on the initial device and ready to be used.")
    ;
}
|
|
||
| } | ||
| } |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we planing to use the WavLMOptions structure?
It is not referenced at the moment.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, in fact it is not used at the moment.
I tried microsoft/wavlm-large for the test case, which outputs the last hidden state alone. It may be useful when someone uses WavLM plus a linear layer (a language-model head) trained with CTC loss, which outputs tokens at the inference stage.