diff --git a/CMakeLists.txt b/CMakeLists.txt
index c46d77374..4dc1d906b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -129,6 +129,7 @@ set(SOURCES
   src/layers/common.cc
   src/layers/decoder.cc
   src/layers/transformer.cc
+  src/layers/wavlm.cc
   src/layers/wav2vec2.cc
   src/layers/wav2vec2bert.cc
   src/layers/whisper.cc
@@ -139,6 +140,7 @@ set(SOURCES
   src/models/model_reader.cc
   src/models/sequence_to_sequence.cc
   src/models/transformer.cc
+  src/models/wavlm.cc
   src/models/wav2vec2.cc
   src/models/wav2vec2bert.cc
   src/models/whisper.cc
diff --git a/README.md b/README.md
index 215249c86..b6452a331 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ The following model types are currently supported:
 * Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper, T5Gemma
 * Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
-* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
+* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa, Wav2vec 2.0, HuBERT, WavLM, Wav2vec2-BERT
 
 Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:
diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index 3d183b07d..068667e07 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -79,6 +79,7 @@ namespace ctranslate2 {
       const StorageView* _relative_position_keys;
       const StorageView* _relative_asymmetric_position_keys;
       const StorageView* _relative_position_values;
+      const StorageView* _gru_relative_position_const;
       dim_t _maximum_relative_position;
       dim_t _relative_left_max_position;
       dim_t _relative_right_max_position;
@@ -86,6 +87,8 @@ namespace ctranslate2 {
       const dim_t _cache_time_dim;
       std::unique_ptr<LayerNorm> _q_norm; // Query normalization
       std::unique_ptr<LayerNorm> _k_norm; // Key normalization
+    protected:
+      const std::unique_ptr<Dense> _gru_relative_position_linear;
     };
   }
 }
diff --git a/include/ctranslate2/layers/wavlm.h b/include/ctranslate2/layers/wavlm.h
new file mode 100644
index 000000000..c656980d7
--- /dev/null
+++ b/include/ctranslate2/layers/wavlm.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <optional>
+
+#include "ctranslate2/layers/transformer.h"
+
+namespace ctranslate2 {
+  namespace layers {
+
+    class WavLMLayerNormConvLayer : public Layer {
+    public:
+      WavLMLayerNormConvLayer(const models::Model& model,
+                              const std::string& scope,
+                              dim_t stride,
+                              dim_t padding);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      dim_t _stride;
+      dim_t _padding;
+      const Conv1D _conv;
+      const LayerNorm _output_norm;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class WavLMPosConvLayer : public Layer {
+    public:
+      WavLMPosConvLayer(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      const Conv1D _conv;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class WavLMEncoder : public Layer {
+    public:
+      WavLMEncoder(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& features, StorageView& output);
+
+      DataType output_type() const override {
+        if (_lm_head) {
+          return (*_lm_head).output_type();
+        }
+        else {
+          return _output_norm.output_type();
+        }
+      }
+
+      dim_t output_size() const override {
+        if (_lm_head) {
+          return (*_lm_head).output_size();
+        }
+        else {
+          return _output_norm.output_size();
+        }
+      }
+
+      dim_t input_size() const {
+        return 1024;
+      }
+
+      bool is_encoded(const StorageView& features) const {
+        // Input features shape: [batch_size, input_size, input_time]
+        // Encoder output shape: [batch_size, input_time // 2, output_size]
+        //
+        // input_time is variable so we check that dimension 1 is different than its original value.
+
+        return (features.rank() == 3
+                && features.dim(2) == output_size()
+                && features.dim(1) != input_size());
+      }
+
+      const StorageView* _upgraded_model;
+
+    private:
+      const StorageView* _return_logits;
+      std::optional<WavLMLayerNormConvLayer> _feat_layer0;
+      std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers;
+      std::optional<LayerNorm> _fp_norm;
+      std::optional<Dense> _fp_ff;
+      std::optional<WavLMPosConvLayer> _pos_conv_embed;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+      const dim_t _num_heads;
+      const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
+      const LayerNorm _output_norm;
+      std::optional<Dense> _lm_head;
+    };
+
+  }
+}
diff --git a/include/ctranslate2/models/wavlm.h b/include/ctranslate2/models/wavlm.h
new file mode 100644
index 000000000..0b797a495
--- /dev/null
+++ b/include/ctranslate2/models/wavlm.h
@@ -0,0 +1,68 @@
+#pragma once
+
+//#include "ctranslate2/generation.h"
+#include "ctranslate2/layers/wavlm.h"
+#include "ctranslate2/models/model.h"
+#include "ctranslate2/replica_pool.h"
+
+namespace ctranslate2 {
+  namespace models {
+
+    struct WavLMOptions {
+      // Maximum generation length.
+      size_t max_length = 448;
+
+      // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
+      size_t sampling_topk = 1;
+
+      // Maximum index of the first predicted timestamp.
+      size_t max_initial_timestamp_index = 50;
+
+      // Suppress blank outputs at the beginning of the sampling.
+      bool suppress_blank = true;
+
+      // List of token IDs to suppress.
+      // -1 will suppress a default set of symbols as defined in the model config.json file.
+      std::vector<int> suppress_tokens = {-1};
+    };
+
+
+    class WavLMModel : public Model {
+    public:
+      const Vocabulary& get_vocabulary() const;
+      size_t current_spec_revision() const override;
+      bool is_quantizable(const std::string& variable_name) const override;
+      bool is_linear_weight(const std::string& variable_name) const override;
+      std::unique_ptr<Model> clone() const override;
+
+      bool use_global_int16_scale() const override {
+        return false;
+      }
+
+    protected:
+      void initialize(ModelReader& model_reader) override;
+    private:
+      std::shared_ptr<const Vocabulary> _vocabulary;
+    };
+
+    class WavLMReplica : public ModelReplica {
+    public:
+      static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);
+
+      WavLMReplica(const std::shared_ptr<const WavLMModel>& model);
+      StorageView encode(StorageView features, const bool to_cpu);
+    private:
+      const std::shared_ptr<const WavLMModel> _model;
+      const std::unique_ptr<layers::WavLMEncoder> _encoder;
+
+      StorageView maybe_encode(StorageView features);
+    };
+
+    class WavLM : public ReplicaPool<WavLMReplica> {
+    public:
+      using ReplicaPool::ReplicaPool;
+      std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
+    };
+
+  }
+}
diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h
index 8834ef651..36a5b2925 100644
--- a/include/ctranslate2/storage_view.h
+++ b/include/ctranslate2/storage_view.h
@@ -230,6 +230,7 @@ namespace ctranslate2 {
     template <typename T>
     StorageView& fill(T value);
     StorageView& zero();
+    StorageView& one();
 
     StorageView& copy_from(const StorageView& other, bool synchronous = false);
diff --git a/python/cpp/module.cc b/python/cpp/module.cc
index 550aea5b2..88a0a20a8 100644
--- a/python/cpp/module.cc
+++ b/python/cpp/module.cc
@@ -86,6 +86,7 @@ PYBIND11_MODULE(_ext, m)
   ctranslate2::python::register_generator(m);
   ctranslate2::python::register_encoder(m);
   ctranslate2::python::register_whisper(m);
+  ctranslate2::python::register_wavlm(m);
   ctranslate2::python::register_wav2vec2(m);
   ctranslate2::python::register_wav2vec2bert(m);
   ctranslate2::python::register_mpi(m);
diff --git a/python/cpp/module.h b/python/cpp/module.h
index 71d4b3b29..94e60da0e 100644
--- a/python/cpp/module.h
+++ b/python/cpp/module.h
@@ -17,6 +17,7 @@ namespace ctranslate2 {
     void register_translation_stats(py::module& m);
     void register_translator(py::module& m);
     void register_whisper(py::module& m);
+    void register_wavlm(py::module& m);
     void register_wav2vec2(py::module& m);
     void register_wav2vec2bert(py::module& m);
     void register_mpi(py::module& m);
diff --git a/python/cpp/wavlm.cc b/python/cpp/wavlm.cc
new file mode 100644
index 000000000..d48e5c8c2
--- /dev/null
+++ b/python/cpp/wavlm.cc
@@ -0,0 +1,123 @@
+#include "module.h"
+
+#include <ctranslate2/models/wavlm.h>
+
+#include "replica_pool.h"
+
+#include <shared_mutex>
+
+namespace ctranslate2 {
+  namespace python {
+
+    class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> {
+    public:
+      using ReplicaPoolHelper::ReplicaPoolHelper;
+
+      StorageView encode(const StorageView& features, const bool to_cpu) {
+        std::shared_lock lock(_mutex);
+        assert_model_is_ready();
+        return _pool->encode(features, to_cpu).get();
+      }
+    };
+
+    void register_wavlm(py::module& m) {
+      py::class_<WavLMWrapper>(
+        m, "WavLM",
+        R"pbdoc(
+            Implements the WavLM speech representation model published by Microsoft.
+ )pbdoc") + + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), + py::arg("model_path"), + py::arg("device")="cpu", + py::kw_only(), + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, + py::arg("tensor_parallel")=false, + py::arg("files")=py::none(), + R"pbdoc( + Initializes a WavLM model from a converted model. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this model on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Number of workers to allow executing multiple batches in parallel. + intra_threads: Number of OpenMP threads per worker (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer + tensor_parallel: run model with tensor parallel mode + files: Load model files from the memory. This argument is a dictionary mapping + file names to file contents as file-like or bytes objects. If this is set, + :obj:`model_path` acts as an identifier for this model. + )pbdoc") + + .def_property_readonly("device", &WavLMWrapper::device, + "Device this model is running on.") + .def_property_readonly("device_index", &WavLMWrapper::device_index, + "List of device IDs where this model is running on.") + .def_property_readonly("compute_type", &WavLMWrapper::compute_type, + "Computation type used by the model.") + .def_property_readonly("num_workers", &WavLMWrapper::num_replicas, + "Number of model workers backing this instance.") + .def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches, + "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel, + "Run model with tensor parallel mode.") + .def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches, + "Number of batches waiting to be processed or currently processed.") + + .def("encode", &WavLMWrapper::encode, + py::arg("features"), + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Encodes the input features. + + Arguments: + features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or + raw audio, as a float array with shape (followed by VAD) + ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]`` + to_cpu: Copy the encoder output to the CPU before returning the value. + + Returns: + The encoder output. + )pbdoc") + + .def("unload_model", &WavLMWrapper::unload_model, + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Unloads the model attached to this wavlm but keep enough runtime context + to quickly resume wavlm on the initial device. + + Arguments: + to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded. 
+ )pbdoc") + + .def("load_model", &WavLMWrapper::load_model, + py::arg("keep_cache")=false, + py::call_guard(), + R"pbdoc( + Loads the model back to the initial device. + + Arguments: + keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists. + )pbdoc") + + .def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded, + "Whether the model is loaded on the initial device and ready to be used.") + ; + } + + } +} diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 2b3f1d351..7924a61b9 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -24,6 +24,7 @@ transformer_spec, wav2vec2_spec, wav2vec2bert_spec, + wavlm_spec, whisper_spec, ) @@ -144,9 +145,13 @@ def _load(self): if self._trust_remote_code: tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code - tokenizer = self.load_tokenizer( - tokenizer_class, self._model_name_or_path, **tokenizer_kwargs - ) + try: + tokenizer = self.load_tokenizer( + tokenizer_class, self._model_name_or_path, **tokenizer_kwargs + ) + except Exception: + tokenizer = None + print("Escape tokenizer, which does not exist.") spec = loader(model, tokenizer) @@ -1077,6 +1082,126 @@ def set_common_layers(self, spec, module): self.set_layer_norm(spec.layer_norm, module.layer_norm) +@register_loader("WavLMConfig") +class WavLMLoader(BartLoader): + @property + def architecture_name(self): + return "WavLMModel" + + def get_model_spec(self, model): + return_hidden = getattr(model.config, "return_hidden", False) + spec = wavlm_spec.WavLMSpec( + model.config.num_feat_extract_layers, + model.encoder.config.num_hidden_layers, + model.encoder.config.num_attention_heads, + return_hidden, + ) + + # layer component name matching (no duplications saving) + for layer in model.encoder.layers: + layer.self_attn = layer.attention + layer.self_attn_layer_norm = layer.layer_norm + layer.activation_fn = layer.feed_forward.intermediate_act_fn + layer.fc1 = layer.feed_forward.intermediate_dense + layer.fc2 = layer.feed_forward.output_dense + + self.set_encoder(spec.encoder, model, model.config) + return spec + + def set_config(self, config, model, tokenizer): + config.layer_norm_epsilon = model.config.layer_norm_eps + + def get_vocabulary(self, model, tokenizer): + return + + def set_vocabulary(self, spec, tokens): + return + + def set_feature_extractor(self, spec, feature_extractor): + spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight + # wavlm has no bias in conv + self.set_layer_norm( + spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm + ) + for spec_layer, module_layer in zip( + spec.feat_layer, feature_extractor.conv_layers[1:] + ): + spec_layer.conv.weight = module_layer.conv.weight + # spec_layer.conv.bias = module_layer.conv.bias // wavlm has no bias + self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm) + + def set_feature_projection(self, spec, feature_projection): + self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm) + self.set_linear(spec.fp_projection, feature_projection.projection) + + def set_pos_conv_embed(self, spec, encoder, config): + # forcing parameters to be set because some transformers version initializes garbage numbers + # conv parameters are float16 so force float32 for the loading + encoder.pos_conv_embed.conv.weight.data = ( + encoder.pos_conv_embed.conv.weight.data.float() + ) + 
+        encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
+        for param in encoder.pos_conv_embed.parameters():
+            param.data = param.data.float()
+        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
+        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
+        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
+
+    def set_encoder(self, spec, model, config):
+        self.set_feature_extractor(spec, model.feature_extractor)
+        self.set_feature_projection(spec, model.feature_projection)
+        self.set_pos_conv_embed(spec, model.encoder, config)
+        self.set_wavlm_encoder_layer(spec, model.encoder)
+
+    def set_wavlm_encoder_layer(self, spec, encoder):
+        self.set_common_layers(spec, encoder)
+
+        for layer_index, (layer_spec, layer) in enumerate(
+            zip(spec.layer, encoder.layers)
+        ):
+            self.set_attention(
+                layer_spec.self_attention,
+                layer.self_attn,
+                self_attention=True,
+                has_rel_attn_embed=(layer_index == 0),
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm,
+                layer.self_attn_layer_norm,
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
+            self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
+
+    def set_attention(
+        self, spec, attention, self_attention=False, has_rel_attn_embed=False
+    ):
+        split_layers = [common_spec.LinearSpec() for _ in range(3)]
+        self.set_linear(split_layers[0], attention.q_proj)
+        self.set_linear(split_layers[1], attention.k_proj)
+        self.set_linear(split_layers[2], attention.v_proj)
+
+        if self_attention:
+            utils.fuse_linear(spec.linear[0], split_layers)
+        else:
+            utils.fuse_linear(spec.linear[0], split_layers[:1])
+            utils.fuse_linear(spec.linear[1], split_layers[1:])
+
+        self.set_linear(spec.linear[-1], attention.out_proj)
+
+        self.set_linear(spec.gru_relative_position_linear, attention.gru_rel_pos_linear)
+        # gru_rel_pos_const is a torch.nn.parameter.Parameter
+        spec.gru_relative_position_const = attention.gru_rel_pos_const.data
+
+        if has_rel_attn_embed:
+            spec.relative_attention_bias = attention.rel_attn_embed.weight
+            spec.relative_attention_max_distance = np.int32(attention.max_distance)
+
+    def set_common_layers(self, spec, module):
+        self.set_layer_norm(spec.layer_norm, module.layer_norm)
+
+
 @register_loader("Wav2Vec2BertConfig")
 class Wav2Vec2BertLoader(BartLoader):
     @property
diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py
index 35a3dca37..93536253a 100644
--- a/python/ctranslate2/models/__init__.py
+++ b/python/ctranslate2/models/__init__.py
@@ -6,6 +6,7 @@
 from ctranslate2._ext import (
     Wav2Vec2,
     Wav2Vec2Bert,
+    WavLM,
     Whisper,
     WhisperGenerationResult,
     WhisperGenerationResultAsync,
diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py
index b4e53fad2..3f93fb6b9 100644
--- a/python/ctranslate2/specs/__init__.py
+++ b/python/ctranslate2/specs/__init__.py
@@ -15,4 +15,5 @@
 )
 from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
 from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec
+from ctranslate2.specs.wavlm_spec import WavLMSpec
 from ctranslate2.specs.whisper_spec import WhisperSpec
diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index 73a7a5e14..0faf81ac3 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -21,6 +21,7 @@ def __init__(
         relative_position=False,
         relative_asymmetric_position=False,
         relative_attention_bias=False,
+        gated_relative_attention_bias=False,
         rms_norm=False,
         rotary_dim=None,
         rotary_interleave=True,
@@ -61,6 +62,10 @@ def __init__(
         self.relative_left_max_position = None
         self.relative_right_max_position = None
 
+        if gated_relative_attention_bias:
+            self.gru_relative_position_const = None
+            self.gru_relative_position_linear = common_spec.LinearSpec()
+
         if original_max_position_embeddings != 0:
             self.original_max_position_embeddings = np.dtype("int32").type(
                 original_max_position_embeddings
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index 5612bb950..f448d436b 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -290,6 +290,7 @@ def __init__(
         self,
         relative_position=False,
         relative_attention_bias=False,
+        gated_relative_attention_bias=False,
         ffn_glu=False,
         rms_norm=False,
         num_heads_kv=None,
@@ -307,6 +308,7 @@ def __init__(
             self_attention=True,
             relative_position=relative_position,
             relative_attention_bias=relative_attention_bias,
+            gated_relative_attention_bias=gated_relative_attention_bias,
             rms_norm=rms_norm,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
diff --git a/python/ctranslate2/specs/wavlm_spec.py b/python/ctranslate2/specs/wavlm_spec.py
new file mode 100644
index 000000000..3aa3e9579
--- /dev/null
+++ b/python/ctranslate2/specs/wavlm_spec.py
@@ -0,0 +1,74 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from ctranslate2.specs import common_spec, model_spec, transformer_spec
+
+
+class WavLMConfig(model_spec.ModelConfig):
+    """Configuration for the WavLM model."""
+
+    def __init__(self, layer_norm_epsilon: float = None, **kwargs):
+        super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
+
+
+class WavLMSpec(model_spec.LanguageModelSpec):
+    def __init__(
+        self,
+        feat_layers,
+        num_layers,
+        num_heads,
+        return_hidden,
+    ):
+        super().__init__()
+        self.vocab_size = np.dtype("int16").type(0)
+        self.encoder = WavLMEncoderSpec(
+            feat_layers,
+            num_layers,
+            num_heads,
+            return_hidden,
+        )
+
+    @property
+    def name(self):
+        return "WavLMSpec"
+
+    @property
+    def revision(self):
+        return 3
+
+    def get_default_config(self):
+        return WavLMConfig()
+
+    def get_vocabulary_size(self):
+        return 0
+
+
+class WavLMLayerNormConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+        self.layer_norm = common_spec.LayerNormSpec()
+
+
+class WavLMPosEmbedConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+
+
+class WavLMEncoderSpec(model_spec.LayerSpec):
+    def __init__(self, feat_layers, num_layers, num_heads, return_hidden):
+        self.num_heads = np.dtype("int16").type(num_heads)
+        self.feat_layer0 = WavLMLayerNormConvLayer()
+        self.feat_layer = [WavLMLayerNormConvLayer() for i in range(feat_layers - 1)]
+        self.fp_layer_norm = common_spec.LayerNormSpec()
+        self.fp_projection = common_spec.LinearSpec()
+        self.pos_conv_embed = WavLMPosEmbedConvLayer()
+        self.layer_norm = common_spec.LayerNormSpec()
+        self.layer = [
+            transformer_spec.TransformerEncoderLayerSpec(
+                gated_relative_attention_bias=True, relative_attention_bias=(i == 0)
+            )
+            for i in range(num_layers)
+        ]
+        # if not return_hidden:
+        #     self.lm_head = common_spec.LinearSpec()
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 05a7416e6..fc1b9d16a 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -1047,6 +1047,89 @@
         assert transcription == expected_transcription[0]
 
 
+class TestWavLM:
+    @classmethod
+    def teardown_class(cls):
+        clear_transformers_cache_in_ci()
+
+    @test_utils.only_on_linux
+    @test_utils.on_available_devices
+    @pytest.mark.parametrize(
+        "model_name,expected_transcription",
+        [
+            (
+                "microsoft/wavlm-large",
+                [
+                    "MISTER QUILTER IS THE APOSSEL OF THE MIDDLE CLASSES AND"
+                    " WE ARE GLAD TO WELCOME HIS GOSPEL",
+                ],
+            ),
+        ],
+    )
+    def test_transformers_wavlm(
+        self,
+        tmp_dir,
+        device,
+        model_name,
+        expected_transcription,
+    ):
+        import torch
+        import transformers
+
+        converter = ctranslate2.converters.TransformersConverter(
+            model_name, load_as_float16="int8"
+        )
+        output_dir = str(tmp_dir.join("ctranslate2_model"))
+        output_dir = converter.convert(output_dir)
+
+        wavlm_processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+        wavlm_processor.save_pretrained(output_dir + "/wavlm_processor")
+        processor = transformers.AutoFeatureExtractor.from_pretrained(
+            output_dir + "/wavlm_processor"
+        )
+
+        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+        cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
+        model = ctranslate2.models.WavLM(
+            output_dir,
+            device=device,
+            device_index=[0],
+            compute_type="int8",
+            intra_threads=cpu_threads,
+            inter_threads=1,
+        )
+        hf_model = transformers.WavLMModel.from_pretrained(model_name)
+
+        speech_array = np.load(
+            os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
+        )
+        input_values = processor(
+            speech_array,
+            padding=True,
+            return_tensors="pt",
+            sampling_rate=16000,
+        ).input_values
+
+        hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
+        hidden_states = ctranslate2.StorageView.from_array(hidden_states)
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
+        output = model.encode(hidden_states, to_cpu=to_cpu)
+        if model.device == "cuda":
+            last_hidden_state = torch.as_tensor(output, device=model.device)[0]
+        else:
+            last_hidden_state = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]
+
+        hg_output = hf_model(input_values.unsqueeze(0))
+
+        similarity = torch.nn.functional.cosine_similarity(
+            last_hidden_state, hg_output.last_hidden_state.flatten(0, -1), dim=0
+        )
+
+        assert similarity == 1.0
+
+
 class TestWav2Vec2Bert:
     @classmethod
     def teardown_class(cls):
         clear_transformers_cache_in_ci()
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index fe36b0197..13f14a1db 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -2,7 +2,6 @@
 #include "ctranslate2/ops/split.h"
 #include "ctranslate2/utils.h"
-
 #include
 #include
 #include
@@ -181,10 +180,13 @@
                                  const StorageView& keys,
                                  const StorageView& values,
                                  const StorageView* values_lengths,
+                                 const StorageView* q,
                                  const StorageView* relative_position_keys,
                                  const StorageView* relative_asymmetric_position_keys,
                                  const StorageView* relative_position_values,
                                  const StorageView* relative_attention_bias,
+                                 const Dense* gru_relative_position_linear,
+                                 const StorageView* gru_relative_position_const,
                                  dim_t relative_left_max_position,
                                  dim_t relative_right_max_position,
                                  dim_t maximum_relative_position,
@@ -230,13 +232,13 @@
                                 *relative_asymmetric_position_keys,
                                 keys_matmul,
                                 output);
 
-      if (relative_attention_bias) {
+      if (relative_attention_bias || (gru_relative_position_linear && gru_relative_position_const)) {
         StorageView local_position_bias(output.dtype(), output.device());
 
         if (!position_bias)
           position_bias = &local_position_bias;
 
-        if (position_bias->empty()) {
+        if (position_bias->empty() && relative_attention_bias) {
           const dim_t query_length = queries.dim(2);
           const dim_t key_length = keys.dim(2);
           *position_bias = compute_relative_bias(*relative_attention_bias,
@@ -256,11 +258,77 @@
           position_bias_per_gpu = &position_bias_tmp;
         }
 
-        DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
-                                 primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
-                                                                    output.data<T>(),
-                                                                    position_bias_per_gpu->size(),
-                                                                    output.size()));
+        if (gru_relative_position_linear && gru_relative_position_const) {
+
+          // WavLM gated relative position bias.
+          // Use dimensions from queries (projected Q) which has shape [B, H, T, dh].
+          const dim_t bsz = queries.dim(0);
+          const dim_t num_heads = queries.dim(1);
+          const dim_t seq_len = queries.dim(2);
+          const dim_t head_dim = queries.dim(3);
+
+          // Reshape raw hidden states `q` [B, T, D] -> [B, T, H, dh] -> transpose to [B, H, T, dh].
+          // This matches HuggingFace WavLM which uses raw hidden_states (not projected Q).
+          StorageView q_reshaped(output.dtype(), output.device());
+          q_reshaped.shallow_copy(const_cast<StorageView&>(*q));
+          q_reshaped.reshape({bsz, seq_len, num_heads, head_dim});
+
+          StorageView q_view(output.dtype(), output.device());
+          ops::Transpose({0, 2, 1, 3})(q_reshaped, q_view);  // [B, T, H, dh] -> [B, H, T, dh]
+          q_reshaped.release();
+
+          // Project to 8 dims per head without in-place aliasing.
+          StorageView relative_position_proj(output.dtype(), output.device());
+          (*gru_relative_position_linear)(q_view, relative_position_proj);  // [B, H, T, 8]
+          q_view.release();
+
+          relative_position_proj.reshape({bsz, num_heads, seq_len, 2, 4});
+          ops::Sum(-1)(relative_position_proj, relative_position_proj);  // [B, H, T, 2, 1]
+          relative_position_proj.squeeze(-1);  // [B, H, T, 2]
+          ops::Sigmoid()(relative_position_proj, relative_position_proj);
+
+          StorageView gate_a(output.dtype(), output.device());
+          StorageView gate_b(output.dtype(), output.device());
+          std::vector<StorageView*> gates{&gate_a, &gate_b};
+          ops::Split(-1, {1, 1})(relative_position_proj, gates);  // split by 2
+          relative_position_proj.release();
+
+          StorageView gru_relative_position_const_tiled(gate_b.shape(), output.dtype(), output.device());
+          ops::Tile(2, seq_len)(*gru_relative_position_const, gru_relative_position_const_tiled);
+          ops::Mul()(gate_b, gru_relative_position_const_tiled, gate_b);
+          gru_relative_position_const_tiled.release();
+          StorageView one_const(gate_b.shape(), output.dtype(), output.device());
+          one_const.one();
+          ops::Sub()(gate_b, one_const, gate_b);
+
+          ops::Mul()(gate_a, gate_b, gate_a);
+
+          StorageView two_const(gate_b.shape(), output.dtype(), output.device());
+          two_const.fill(2.0f);
+          ops::Add()(gate_a, two_const, gate_a);
+          gate_a.reshape({num_heads, seq_len, 1});  // torch.Size([16, 128, 1])
+          gate_b.release();
+
+          StorageView gated_position_bias(output.dtype(), output.device());
+          ops::Tile(2, seq_len)(gate_a, gated_position_bias);  // [16, 128, 128]
+          gate_a.release();
+
+          ops::Mul()(gated_position_bias, *position_bias_per_gpu, gated_position_bias);  // torch.Size([16, 128, 128])
+          gated_position_bias.expand_dims(0);  // torch.Size([1, 16, 128, 128])
+
+          DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
+                                   primitives<D>::add_batch_broadcast(gated_position_bias.data<T>(),
+                                                                      output.data<T>(),
+                                                                      gated_position_bias.size(),
+                                                                      output.size()));
+
+        } else {
+          DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
+                                   primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
+                                                                      output.data<T>(),
+                                                                      position_bias_per_gpu->size(),
+                                                                      output.size()));
+        }
       }
 
       if (alibi)
@@ -305,6 +373,8 @@ namespace ctranslate2 {
       , _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys"))
       , _relative_asymmetric_position_keys(model.get_variable_if_exists(scope + "/relative_asymmetric_position_keys"))
       , _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values"))
+      , _gru_relative_position_const(model.get_variable_if_exists(scope + "/gru_relative_position_const"))
+      , _gru_relative_position_linear(build_optional_layer<Dense>(model, scope + "/gru_relative_position_linear"))
       , _merge_time_and_head_dims(_multi_query
                                   && !_relative_attention_bias
                                   && !_relative_position_keys
@@ -460,12 +530,20 @@
       StorageView keys_proj(dtype, device);
       StorageView values_proj(dtype, device);
+      StorageView layer_normed_hidden(dtype, device);
 
       const StorageView* q = &queries;
       if (_layer_norm && _pre_norm) {
         (*_layer_norm)(queries, queries_proj);
         q = &queries_proj;
+        if (_gru_relative_position_linear) {
+          layer_normed_hidden = queries_proj;  // deep copy
+        }
       }
 
+      // Point q to the saved copy if we have gru_relative_position_linear.
+      const StorageView* q_for_rel_pos = (_gru_relative_position_linear && !layer_normed_hidden.empty())
+                                         ? &layer_normed_hidden : q;
+
       _linear[0](*q, fused_proj);
 
       dim_t beam_size = 1;
@@ -556,10 +634,13 @@
                               keys_proj,
                               values_proj,
                               values_lengths,
+                              q_for_rel_pos,  // Pass raw hidden states (after layer_norm) for gated rel pos bias.
                               _relative_position_keys,
                               _relative_asymmetric_position_keys,
                               _relative_position_values,
                               _relative_attention_bias,
+                              _gru_relative_position_linear.get(),
+                              _gru_relative_position_const,
                               _relative_left_max_position,
                               _relative_right_max_position,
                               _maximum_relative_position,
diff --git a/src/layers/wavlm.cc b/src/layers/wavlm.cc
new file mode 100644
index 000000000..b827736d2
--- /dev/null
+++ b/src/layers/wavlm.cc
@@ -0,0 +1,130 @@
+#include "ctranslate2/layers/wavlm.h"
+
+namespace ctranslate2 {
+  namespace layers {
+
+    WavLMLayerNormConvLayer::WavLMLayerNormConvLayer(const models::Model& model,
+                                                     const std::string& scope,
+                                                     dim_t stride,
+                                                     dim_t padding)
+      : _stride(stride)
+      , _padding(padding)
+      , _conv(model, scope + "/conv", _stride, _padding)
+      , _transpose({0, 2, 1})
+      , _output_norm(model, scope + "/layer_norm") {
+    }
+
+    void WavLMLayerNormConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("WavLMLayerNormConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      buffer = std::move(input);
+      _conv(buffer, output);
+      _transpose(output, buffer);
+      _output_norm(buffer, output);
+      _transpose(output, buffer);
+      _gelu(buffer, output);
+    }
+
+    WavLMPosConvLayer::WavLMPosConvLayer(const models::Model& model, const std::string& scope)
+      : _conv(model, scope + "/conv", /*stride=*/1, /*padding=*/64, /*dilation=*/1, /*groups=*/16)
+      , _transpose({0, 2, 1}) {
+    }
+
+    void WavLMPosConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("WavLMPosConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      StorageView buffer2(input.dtype(), input.device());
+      _transpose(input, buffer);
+      _conv(buffer, buffer2);
+      ops::Split(2, {buffer.dim(2), 1})(buffer2, buffer, output);
+      _gelu(buffer, buffer);
+      _transpose(buffer, buffer2);
+      ops::Add()(input, buffer2, output);
+    }
+
+    WavLMEncoder::WavLMEncoder(const models::Model& model, const std::string& scope)
+      : _return_logits(model.get_variable_if_exists(scope + "/lm_head/weight"))
"/fp_projection/weight")) + , _num_heads(model.get_attribute_with_default(scope + "/num_heads", 8)) + , _transpose({0, 2, 1}) + , _layers(build_layers_list(model, + scope + "/layer", + _num_heads, + /*pre_norm=*/true, + ops::ActivationType::GELU)) + , _output_norm(model, scope + "/layer_norm") + { + if (_upgraded_model) { + _feat_layer0.emplace(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0); + _feat_layers.emplace(build_layers_list(model, + scope + "/feat_layer", + /*stride=*/2, + /*padding=*/0)); + _fp_norm.emplace(model, scope + "/fp_layer_norm"); + _fp_ff.emplace(model, scope + "/fp_projection", nullptr, true); + _pos_conv_embed.emplace(model, scope + "/pos_conv_embed"); + if (_return_logits) { + _lm_head.emplace(model, scope + "/lm_head", nullptr, true); + } + } + } + + void WavLMEncoder::operator()(const StorageView& features, StorageView& output) { + PROFILE("WavLMEncoder"); + + // SAD in front-end handles the input length + if (features.rank() != 3) + throw std::invalid_argument("Expected input features to have 3 dimensions, but got " + + std::to_string(features.rank()) + + " dimension(s) instead"); + if (_upgraded_model) { + // WavLMFeatureExtractor------------------------------------ + StorageView feat_buffer(features.dtype(), features.device()); + StorageView feat_buffer2(features.dtype(), features.device()); + feat_buffer = std::move(features); + (*_feat_layer0)(feat_buffer, output); //_feat_layer0(feat_buffer, output); + feat_buffer = std::move(output); + for (dim_t l = 0; l < _feat_layers->size(); l++) { + (*_feat_layers.value()[l])(feat_buffer, output); + if (l < _feat_layers->size() - 1 ) { + feat_buffer = std::move(output); + } + } + _transpose(output, feat_buffer); + // WavLMFeatureProjection----------------------------------- + (*_fp_norm)(feat_buffer, output); //_fp_norm(feat_buffer, output); + (*_fp_ff)(output, feat_buffer); //_fp_ff(output, feat_buffer); + // WavLMEncoderStableLayerNorm + // WavLMPositionalConvEmbedding----------------------------- + (*_pos_conv_embed)(feat_buffer, feat_buffer2); //_pos_conv_embed(feat_buffer, feat_buffer2); + // WavLMEncoderLayerStableLayerNorm------------------------- + StorageView position_bias(features.dtype(), features.device()); + for (const auto& layer : _layers) { + (*layer)(feat_buffer2, nullptr, feat_buffer, nullptr, &position_bias); + feat_buffer2 = std::move(feat_buffer); + } + if (_return_logits) { + _output_norm(feat_buffer2, feat_buffer); // default is True + (*_lm_head)(feat_buffer, output); + } + else { + _output_norm(feat_buffer2, output); + } + } + else { // backward compatibility for the previous converted model + StorageView input(features.dtype(), features.device()); + input = features; + StorageView position_bias(features.dtype(), features.device()); + for (const auto& layer : _layers) { + (*layer)(input, nullptr, output, nullptr, &position_bias); + input = std::move(output); + } + + _output_norm(input, output); + } + } + + } +} diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc index 059051f5d..8d403fa64 100644 --- a/src/models/model_factory.cc +++ b/src/models/model_factory.cc @@ -3,6 +3,7 @@ #include #include "ctranslate2/models/whisper.h" +#include "ctranslate2/models/wavlm.h" #include "ctranslate2/models/wav2vec2.h" #include "ctranslate2/models/wav2vec2bert.h" #include "ctranslate2/models/transformer.h" @@ -23,6 +24,8 @@ namespace ctranslate2 { register_model("WhisperSpec"); + register_model("WavLMSpec"); + register_model("Wav2Vec2Spec"); register_model("Wav2Vec2BertSpec"); 
diff --git a/src/models/wavlm.cc b/src/models/wavlm.cc
new file mode 100644
index 000000000..c75f7c4c4
--- /dev/null
+++ b/src/models/wavlm.cc
@@ -0,0 +1,122 @@
+#include "ctranslate2/models/wavlm.h"
+
+#include <algorithm>
+
+#include "ctranslate2/decoding.h"
+
+#include "dispatch.h"
+#include "dtw.h"
+
+#ifdef CT2_WITH_CUDA
+#  include "cuda/utils.h"
+#endif
+
+
+namespace ctranslate2 {
+  namespace models {
+
+    const Vocabulary& WavLMModel::get_vocabulary() const {
+      return *_vocabulary;
+    }
+
+    size_t WavLMModel::current_spec_revision() const {
+      return 3;
+    }
+
+    void WavLMModel::initialize(ModelReader& model_reader) {
+      VocabularyInfo vocab_info;
+      vocab_info.unk_token = "[UNK]";
+      vocab_info.bos_token = "<s>";
+      vocab_info.eos_token = "</s>";
+
+      _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
+      if (!_vocabulary)
+        throw std::runtime_error("Cannot load the vocabulary from the model directory");
+    }
+
+    bool WavLMModel::is_quantizable(const std::string& variable_name) const {
+      return Model::is_quantizable(variable_name);
+    }
+
+    bool WavLMModel::is_linear_weight(const std::string& variable_name) const {
+      return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos;
+    }
+
+    std::unique_ptr<Model> WavLMModel::clone() const {
+      return std::make_unique<WavLMModel>(*this);
+    }
+
+
+    std::unique_ptr<WavLMReplica> WavLMReplica::create_from_model(const Model& model) {
+      if (!dynamic_cast<const WavLMModel*>(&model))
+        throw std::invalid_argument("The model is not a WavLM model");
+
+      const auto scoped_device_setter = model.get_scoped_device_setter();
+      const auto model_ptr = model.shared_from_this();
+      const auto concrete_model = std::static_pointer_cast<const WavLMModel>(model_ptr);
+      return std::make_unique<WavLMReplica>(concrete_model);
+    }
+
+    WavLMReplica::WavLMReplica(const std::shared_ptr<const WavLMModel>& model)
+      : ModelReplica(model)
+      , _model(model)
+      , _encoder(std::make_unique<layers::WavLMEncoder>(*model, "encoder"))
+    {
+    }
+
+    StorageView WavLMReplica::encode(StorageView features, const bool to_cpu) {
+      PROFILE("WavLMReplica::encode");
+
+#ifdef CT2_WITH_CUDA
+      const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false);
+#endif
+
+      const auto scoped_device_setter = _model->get_scoped_device_setter();
+      const Device device = _model->device();
+      const DataType dtype = _encoder->output_type();
+      features.move_to(device, dtype);
+
+      StorageView encoder_output(dtype, device);
+      if (_encoder->_upgraded_model) {
+        encoder_output = maybe_encode(std::move(features));
+      }
+      else {
+        (*_encoder)(features, encoder_output);
+      }
+
+      if (to_cpu) {
+        if (device != Device::CPU)
+          encoder_output = encoder_output.to(Device::CPU);
+
+        return encoder_output;
+      }
+
+      // Ensure all operations are finished before returning the output.
+      synchronize_stream(device);
+
+      return encoder_output;
+    }
+
+    StorageView WavLMReplica::maybe_encode(StorageView features) {
+      const Device device = _model->device();
+      const DataType dtype = _encoder->output_type();
+
+      features.move_to(device, dtype);
+
+      if (_encoder->is_encoded(features))
+        return features;
+
+      StorageView encoder_output(dtype, device);
+      (*_encoder)(features, encoder_output);
+      return encoder_output;
+    }
+
+    std::future<StorageView> WavLM::encode(const StorageView& features, const bool to_cpu) {
+      return post<StorageView>(
+        [features = features.sync_copy(), to_cpu](WavLMReplica& replica) mutable {
+          return replica.encode(std::move(features), to_cpu);
+        });
+    }
+
+  }
+}
diff --git a/src/storage_view.cc b/src/storage_view.cc
index 0cbdf25c2..d91efb275 100644
--- a/src/storage_view.cc
+++ b/src/storage_view.cc
@@ -405,6 +405,11 @@ namespace ctranslate2 {
     return *this;
   }
 
+  StorageView& StorageView::one() {
+    DEVICE_AND_TYPE_DISPATCH(_device, _dtype, primitives<D>::fill(data<T>(), T(1), _size));
+    return *this;
+  }
+
   template <typename T>
   StorageView& StorageView::copy_from(const T* data, dim_t size, Device device, bool synchronous) {
     if (size != _size)
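Below is a minimal usage sketch of the Python API added by this patch. The converted model directory, compute type, and input length are placeholders, and the converter command assumes ct2-transformers-converter picks up the new WavLMConfig loader; only the ctranslate2.models.WavLM class and its encode() signature come from the patch itself.

import numpy as np
import ctranslate2

# Convert the Hugging Face checkpoint first, for example:
#   ct2-transformers-converter --model microsoft/wavlm-large --output_dir wavlm-large-ct2
model = ctranslate2.models.WavLM("wavlm-large-ct2", device="cpu", compute_type="int8")

# Raw 16 kHz waveform with shape [batch_size, 1, num_samples] (placeholder data).
audio = np.random.randn(1, 1, 131200).astype(np.float32)
features = ctranslate2.StorageView.from_array(audio)

output = model.encode(features, to_cpu=True)
hidden_states = np.array(output)  # [batch_size, num_frames, 1024]
print(hidden_states.shape)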