Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ set(SOURCES
src/layers/common.cc
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wavlm.cc
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
Expand All @@ -139,6 +140,7 @@ set(SOURCES
src/models/model_reader.cc
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wavlm.cc
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The following model types are currently supported:

* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper, T5Gemma
* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
* Encoder-only models: BERT, DistilBERT, XLM-RoBERTa, Wav2vec 2.0, HuBERT, WavLM, Wav2vec2-BERT

Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:

Expand Down
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ namespace ctranslate2 {
const StorageView* _relative_position_keys;
const StorageView* _relative_asymmetric_position_keys;
const StorageView* _relative_position_values;
const StorageView* _gru_relative_position_const;
dim_t _maximum_relative_position;
dim_t _relative_left_max_position;
dim_t _relative_right_max_position;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
std::unique_ptr<const LayerNorm> _q_norm; // Query normalization
std::unique_ptr<const LayerNorm> _k_norm; // Key normalization
protected:
const std::unique_ptr<const Dense> _gru_relative_position_linear;
};
}
}
112 changes: 112 additions & 0 deletions include/ctranslate2/layers/wavlm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#pragma once

#include <optional>
#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
namespace layers {

// Feature-extractor building block for WavLM.
//
// Holds a Conv1D, a LayerNorm, a Transpose op and a GELU op (see the private
// members below). The implementation of operator() is not in this header;
// presumably it chains convolution, layer normalization and GELU -- TODO confirm
// against src/layers/wavlm.cc.
class WavLMLayerNormConvLayer : public Layer {
public:
// model:   model providing the variables for this layer.
// scope:   variable scope (name prefix) of this layer inside the model.
// stride:  stride value stored in _stride (presumably forwarded to the convolution).
// padding: padding value stored in _padding (presumably forwarded to the convolution).
WavLMLayerNormConvLayer(const models::Model& model,
const std::string& scope,
dim_t stride,
dim_t padding);

// Applies the layer to `input` and writes the result into `output`.
void operator()(const StorageView& input, StorageView& output) const;

// The output type is the one reported by the underlying convolution.
DataType output_type() const override {
return _conv.output_type();
}

// The output size is the one reported by the underlying convolution.
dim_t output_size() const override {
return _conv.output_size();
}

private:
// NOTE(review): members are initialized in declaration order; keep _stride and
// _padding ahead of _conv if the constructor builds _conv from them.
dim_t _stride;
dim_t _padding;
const Conv1D _conv;
const LayerNorm _output_norm;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

// Convolutional block of the WavLM encoder, stored by WavLMEncoder as
// `_pos_conv_embed`.
//
// Holds a Conv1D, a Transpose op and a GELU op. The body of operator() is not
// visible in this header; presumably it computes the convolutional positional
// embedding -- TODO confirm against src/layers/wavlm.cc.
class WavLMPosConvLayer : public Layer {
public:
// model: model providing the variables for this layer.
// scope: variable scope (name prefix) of this layer inside the model.
WavLMPosConvLayer(const models::Model& model, const std::string& scope);

// Applies the layer to `input` and writes the result into `output`.
void operator()(const StorageView& input, StorageView& output) const;

// The output type is the one reported by the underlying convolution.
DataType output_type() const override {
return _conv.output_type();
}

// The output size is the one reported by the underlying convolution.
dim_t output_size() const override {
return _conv.output_size();
}

private:
const Conv1D _conv;
const ops::Transpose _transpose;
const ops::GELU _gelu;
};

// WavLM encoder layer.
//
// Several sub-layers are optional (feature-extractor convolutions, feature
// projection, positional convolution, LM head); they are held in std::optional
// members and only some of them may be present for a given converted model.
class WavLMEncoder : public Layer {
public:
  // model: model providing the variables for the encoder.
  // scope: variable scope (name prefix) of the encoder inside the model.
  WavLMEncoder(const models::Model& model, const std::string& scope);

  // Runs the encoder on `features` and writes the result into `output`.
  // Note that this call operator is not const.
  void operator()(const StorageView& features, StorageView& output);

  // Type produced by the LM head when one is present, otherwise by the final
  // layer normalization.
  DataType output_type() const override {
    return _lm_head ? _lm_head->output_type() : _output_norm.output_type();
  }

  // Size produced by the LM head when one is present, otherwise by the final
  // layer normalization.
  dim_t output_size() const override {
    return _lm_head ? _lm_head->output_size() : _output_norm.output_size();
  }

  // Expected size of dimension 1 of a not-yet-encoded input (hard-coded).
  dim_t input_size() const {
    return 1024;
  }

  // Heuristic telling whether `features` is already an encoder output.
  //
  // A tensor is considered encoded when it has rank 3, its last dimension
  // equals output_size(), and its dimension 1 differs from input_size().
  //
  // NOTE(review): the previous comment described the input as
  // [batch_size, input_size, input_time] and the output as
  // [batch_size, input_time // 2, output_size]; the "// 2" ratio looks
  // inherited from another encoder -- confirm the actual WavLM shapes.
  bool is_encoded(const StorageView& features) const {
    if (features.rank() != 3)
      return false;
    if (features.dim(2) != output_size())
      return false;
    return features.dim(1) != input_size();
  }

  // Public on purpose in the original declaration; read by code outside this class.
  const StorageView* _upgraded_model;

private:
  // Declaration order below is the construction order; do not reorder.
  const StorageView* _return_logits;
  std::optional<WavLMLayerNormConvLayer> _feat_layer0;
  std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers;
  std::optional<LayerNorm> _fp_norm;
  std::optional<Dense> _fp_ff;
  std::optional<WavLMPosConvLayer> _pos_conv_embed;
  const ops::Transpose _transpose;
  const ops::GELU _gelu;
  const dim_t _num_heads;
  const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
  const LayerNorm _output_norm;
  std::optional<Dense> _lm_head;
};

}
}
68 changes: 68 additions & 0 deletions include/ctranslate2/models/wavlm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#pragma once

//#include "ctranslate2/generation.h"
#include "ctranslate2/layers/wavlm.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

struct WavLMOptions {
// Maximum generation length.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we planning to use the WavLMOptions structure?
It is not referenced at the moment.

Copy link
Contributor Author

@a2d8a4v a2d8a4v Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, in fact it is not used at the moment.
I tried microsoft/wavlm-large for the test case, which outputs only the last hidden state. The options may become useful when someone uses WavLM plus a linear layer (a language-model head) trained with CTC loss, which outputs tokens at the inference stage.

size_t max_length = 448;

// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
size_t sampling_topk = 1;

// Maximum index of the first predicted timestamp.
size_t max_initial_timestamp_index = 50;

// Suppress blank outputs at the beginning of the sampling.
bool suppress_blank = true;

// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};
};


// Model definition for WavLM.
//
// Overrides the Model hooks declared below and keeps a shared, immutable
// Vocabulary. All member functions except use_global_int16_scale() are defined
// outside this header.
class WavLMModel : public Model {
public:
// Accessor for the vocabulary (presumably the one held in _vocabulary -- TODO confirm).
const Vocabulary& get_vocabulary() const;
// Revision number of the converted-model specification this class supports.
size_t current_spec_revision() const override;
// Whether the variable named `variable_name` may be quantized.
bool is_quantizable(const std::string& variable_name) const override;
// Whether the variable named `variable_name` is the weight of a linear layer.
bool is_linear_weight(const std::string& variable_name) const override;
// Returns a copy of this model as a base-class pointer.
std::unique_ptr<Model> clone() const override;

// This model never uses a global int16 scale.
bool use_global_int16_scale() const override {
return false;
}

protected:
// Initialization hook that receives the model reader.
void initialize(ModelReader& model_reader) override;
private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

// Model replica for WavLM: shares the immutable WavLMModel and owns a
// WavLMEncoder instance.
class WavLMReplica : public ModelReplica {
public:
// Factory building a replica from a generic Model reference.
static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);

WavLMReplica(const std::shared_ptr<const WavLMModel>& model);
// Encodes `features` (taken by value) and returns the encoder output.
// to_cpu: presumably moves the result to the CPU before returning, as the
// Python binding documents for the same flag -- TODO confirm in the .cc file.
StorageView encode(StorageView features, const bool to_cpu);
private:
const std::shared_ptr<const WavLMModel> _model;
const std::unique_ptr<layers::WavLMEncoder> _encoder;

// Presumably skips the encoder when `features` is already encoded
// (see layers::WavLMEncoder::is_encoded) -- TODO confirm in the .cc file.
StorageView maybe_encode(StorageView features);
};

// Pool of WavLMReplica workers. Constructors are inherited from ReplicaPool.
class WavLM : public ReplicaPool<WavLMReplica> {
public:
using ReplicaPool::ReplicaPool;
// Asynchronous encoding: returns a future holding the encoder output for `features`.
// to_cpu: same flag as WavLMReplica::encode.
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
};

}
}
1 change: 1 addition & 0 deletions include/ctranslate2/storage_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ namespace ctranslate2 {
template <typename T>
StorageView& fill(T value);
StorageView& zero();
StorageView& one();

StorageView& copy_from(const StorageView& other, bool synchronous = false);

Expand Down
1 change: 1 addition & 0 deletions python/cpp/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_generator(m);
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wavlm(m);
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
Expand Down
1 change: 1 addition & 0 deletions python/cpp/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace ctranslate2 {
void register_translation_stats(py::module& m);
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wavlm(py::module& m);
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);
Expand Down
123 changes: 123 additions & 0 deletions python/cpp/wavlm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#include "module.h"

#include <ctranslate2/models/wavlm.h>

#include "replica_pool.h"

#include <iostream>

namespace ctranslate2 {
namespace python {

// Python-facing helper around a pool of WavLM replicas.
// Constructors are inherited from ReplicaPoolHelper.
class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Submits `features` to the replica pool and waits for the encoder output.
  //
  // A shared lock on _mutex is held for the whole call, and the call fails
  // through assert_model_is_ready() when the model is not ready.
  // to_cpu is forwarded unchanged to the pool.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    std::shared_lock guard(_mutex);
    assert_model_is_ready();
    auto pending = _pool->encode(features, to_cpu);
    return pending.get();
  }
};

// Registers the Python class "WavLM" on module `m`, backed by WavLMWrapper.
//
// The registration is a single fluent pybind11 chain: class docstring,
// constructor, read-only properties, then the encode / unload_model /
// load_model methods. The raw-string docstrings are user-facing text.
void register_wavlm(py::module& m) {
py::class_<WavLMWrapper>(
m, "WavLM",
R"pbdoc(
Implements the WavLM speech recognition model published by Microsoft.
)pbdoc")

// Constructor. `model_path` and `device` may be passed positionally; every
// argument declared after py::kw_only() is keyword-only on the Python side.
.def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
py::arg("device_index")=0,
py::arg("compute_type")="default",
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
py::arg("flash_attention")=false,
py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes a WavLM model from a converted model.

Arguments:
model_path: Path to the CTranslate2 model directory.
device: Device to use (possible values are: cpu, cuda, auto).
device_index: Device IDs where to place this model on.
compute_type: Model computation type or a dictionary mapping a device name
to the computation type (possible values are: default, auto, int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
inter_threads: Number of workers to allow executing multiple batches in parallel.
intra_threads: Number of OpenMP threads per worker (0 to use a default value).
max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
flash_attention: run model with flash attention 2 for self-attention layer
tensor_parallel: run model with tensor parallel mode
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
)pbdoc")

// Read-only properties forwarded to the corresponding WavLMWrapper accessors.
.def_property_readonly("device", &WavLMWrapper::device,
"Device this model is running on.")
.def_property_readonly("device_index", &WavLMWrapper::device_index,
"List of device IDs where this model is running on.")
.def_property_readonly("compute_type", &WavLMWrapper::compute_type,
"Computation type used by the model.")
// Exposed to Python as "num_workers" but backed by num_replicas().
.def_property_readonly("num_workers", &WavLMWrapper::num_replicas,
"Number of model workers backing this instance.")
.def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
.def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel,
"Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")

// The three methods below release the GIL for the duration of the C++ call
// (py::call_guard<py::gil_scoped_release>).
.def("encode", &WavLMWrapper::encode,
py::arg("features"),
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Encodes the input features.

Arguments:
features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
raw audio, as a float array with shape (followed by VAD)
``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
to_cpu: Copy the encoder output to the CPU before returning the value.

Returns:
The encoder output.
)pbdoc")

.def("unload_model", &WavLMWrapper::unload_model,
py::arg("to_cpu")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Unloads the model attached to this wavlm but keep enough runtime context
to quickly resume wavlm on the initial device.

Arguments:
to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
)pbdoc")

.def("load_model", &WavLMWrapper::load_model,
py::arg("keep_cache")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Loads the model back to the initial device.

Arguments:
keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
)pbdoc")

.def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded,
"Whether the model is loaded on the initial device and ready to be used.")
;
}

}
}
Loading
Loading