From d476985199e270a1cbb2bb189ac493d6afae32d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?=
Date: Thu, 25 Sep 2025 14:05:45 +0800
Subject: [PATCH 1/9] Create median_filter_gpu.cu

---
 src/ops/median_filter_gpu.cu | 151 +++++++++++++++++++++++++++++++++++
 1 file changed, 151 insertions(+)
 create mode 100644 src/ops/median_filter_gpu.cu

diff --git a/src/ops/median_filter_gpu.cu b/src/ops/median_filter_gpu.cu
new file mode 100644
index 000000000..8dcc61031
--- /dev/null
+++ b/src/ops/median_filter_gpu.cu
@@ -0,0 +1,151 @@
+#include "ctranslate2/ops/median_filter.h"
+
+#include <cuda_fp16.h>
+#ifdef CUDA_BF16_AVAILABLE
+#include <cuda_bf16.h>
+#endif
+
+#include "type_dispatch.h"
+#include "cuda/helpers.h"
+#include <type_traits>
+
+namespace ctranslate2 {
+  namespace ops {
+
+    constexpr dim_t num_threads = 256;
+
+    // Conversion helpers
+    __device__ __forceinline__ float to_float(float v) { return v; }
+    __device__ __forceinline__ float to_float(const half v) { return __half2float(v); }
+#ifdef CUDA_BF16_AVAILABLE
+    __device__ __forceinline__ float to_float(const __nv_bfloat16 v) { return __bfloat162float(v); }
+#endif
+
+    __device__ __forceinline__ float from_float(float v) { return v; }
+    __device__ __forceinline__ half from_float_half(float v) { return __float2half(v); }
+#ifdef CUDA_BF16_AVAILABLE
+    __device__ __forceinline__ __nv_bfloat16 from_float_bf16(float v) { return __float2bfloat16(v); }
+#endif
+
+    namespace {
+      constexpr int kMaxWindow = 129;  // supports window widths up to 129 (rank 64)
+    }
+
+    template <typename DeviceT, int kMax>
+    __global__ void sliding_median_lastdim_kernel(const DeviceT* input,
+                                                  DeviceT* output,
+                                                  int rows,
+                                                  int depth,
+                                                  int width) {
+      const int total = rows * depth;
+      const int rank = width / 2;
+
+      // Grid-stride loop: the launch below clamps the grid size, so each
+      // thread may be responsible for more than one output element.
+      for (int tid = blockIdx.x * blockDim.x + threadIdx.x;
+           tid < total;
+           tid += gridDim.x * blockDim.x) {
+        const int row = tid / depth;
+        const int col = tid % depth;
+
+        if (depth <= rank || width > kMax) {
+          output[tid] = input[tid];
+          continue;
+        }
+
+        float window[kMax];
+
+        const int row_offset = row * depth;
+        // Reflection gather.
+        for (int k = -rank; k <= rank; ++k) {
+          int read = col + k;
+          if (read < 0) read = -read;
+          if (read >= depth) read = 2 * depth - read - 2;
+          window[k + rank] = to_float(input[row_offset + read]);
+        }
+
+        // Insertion sort (width is small: <= kMax, typically < 129).
+        for (int i = 1; i < width; ++i) {
+          float key = window[i];
+          int j = i - 1;
+          while (j >= 0 && window[j] > key) {
+            window[j + 1] = window[j];
+            --j;
+          }
+          window[j + 1] = key;
+        }
+        float median = window[rank];
+
+        if constexpr (std::is_same<DeviceT, float>::value) {
+          output[tid] = median;
+        } else if constexpr (std::is_same<DeviceT, half>::value) {
+          output[tid] = from_float_half(median);
+#ifdef CUDA_BF16_AVAILABLE
+        } else if constexpr (std::is_same<DeviceT, __nv_bfloat16>::value) {
+          output[tid] = from_float_bf16(median);
+#endif
+        }
+      }
+    }
+
+    template <Device D, typename T>
+    void MedianFilter::compute(const StorageView& input,
+                               const dim_t axis_size,
+                               StorageView& output) const {
+      output.resize_as(input);
+      const int depth = static_cast<int>(axis_size);
+      const int rows = static_cast<int>(input.size() / depth);
+      const int width = static_cast<int>(_width);
+      const int rank = width / 2;
+
+      // Host-side guards and fallbacks.
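+      // width == 1 is an identity filter; even widths have no unique middle
+      // element; and widths above kMaxWindow would overflow the fixed-size
+      // per-thread window buffer used by the CUDA kernel above.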
+      if (width <= 1) {
+        if (&output != &input)
+          output.copy_from(input);
+        return;
+      }
+      if ((width & 1) == 0)
+        throw std::invalid_argument("MedianFilter width must be odd");
+      if (width > kMaxWindow)
+        throw std::invalid_argument("MedianFilter width exceeds supported GPU max (" + std::to_string(kMaxWindow) + ")");
+      if (depth <= rank) {
+        if (&output != &input)
+          output.copy_from(input);
+        return;
+      }
+
+      // Grid configuration
+      const int total = rows * depth;
+      int blocks = (total + num_threads - 1) / num_threads;
+      if (blocks > cuda::max_blocks) {
+        blocks = cuda::max_blocks;
+      }
+
+      using device_t = cuda::device_type<T>;
+      const device_t* in_ptr = cuda::device_cast(input.data<T>());
+      device_t* out_ptr = cuda::device_cast(output.data<T>());
+      sliding_median_lastdim_kernel<device_t, kMaxWindow><<<blocks, num_threads, 0, cuda::get_cuda_stream()>>>(
+          in_ptr,
+          out_ptr,
+          rows,
+          depth,
+          width);
+      CUDA_CHECK(cudaGetLastError());
+      CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
+#define DECLARE_IMPL(T)                                                 \
+    template void                                                       \
+    MedianFilter::compute<Device::CUDA, T>(const StorageView& input,    \
+                                           const dim_t axis_size,       \
+                                           StorageView& output) const;
+
+    DECLARE_IMPL(float)
+    DECLARE_IMPL(float16_t)
+    DECLARE_IMPL(bfloat16_t)
+
+  }
+}

From d51d66143b5f3a4701e1fa590c6220e958fc3cf2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?=
Date: Thu, 25 Sep 2025 14:06:15 +0800
Subject: [PATCH 2/9] Create median_filter_cpu.cc

---
 src/ops/median_filter_cpu.cc | 62 ++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 src/ops/median_filter_cpu.cc

diff --git a/src/ops/median_filter_cpu.cc b/src/ops/median_filter_cpu.cc
new file mode 100644
index 000000000..e92f5c95d
--- /dev/null
+++ b/src/ops/median_filter_cpu.cc
@@ -0,0 +1,62 @@
+#include "ctranslate2/ops/median_filter.h"
+
+#include <algorithm>
+
+#include <cmath>
+#include "cpu/parallel.h"
+#include "type_dispatch.h"
+
+namespace ctranslate2 {
+  namespace ops {
+
+    template <Device D, typename T>
+    void MedianFilter::compute(const StorageView& input,
+                               const dim_t axis_size,
+                               StorageView& output) const {
+      const auto* src = input.data<T>();
+      auto* dst = output.data<T>();
+
+
+      const dim_t depth = axis_size;
+      const dim_t batch_size = input.size() / depth;
+      const dim_t rank = _width / 2;
+
+      if (depth <= rank) {
+        output.copy_from(input);
+        return;
+      }
+
+      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
+        StorageView window_storage({_width}, DataType::FLOAT32);
+        auto* window = window_storage.data<float>();
+
+        for (dim_t i = begin; i < end; ++i) {
+          const dim_t offset = i * depth;
+          const auto* in = src + offset;
+          auto* out = dst + offset;
+
+          for (dim_t j = 0; j < depth; ++j) {
+            for (dim_t k = -rank; k <= rank; ++k) {
+              dim_t read = std::abs(j + k);
+              if (read >= depth)
+                read = depth - (read - depth) - 2;
+              window[k + rank] = in[read];
+            }
+
+            std::nth_element(window, window + rank, window + _width);
+            out[j] = window[rank];
+          }
+        }
+      });
+    }
+
+#define DECLARE_IMPL(T)                                                \
+    template void                                                      \
+    MedianFilter::compute<Device::CPU, T>(const StorageView& input,    \
+                                          const dim_t axis_size,       \
+                                          StorageView& output) const;
+
+    DECLARE_IMPL(float)
+
+  }
+}

From f35fb6911498b757ad934c0d7936a700e7091595 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?=
Date: Thu, 25 Sep 2025 14:07:20 +0800
Subject: [PATCH 3/9] Update median_filter.h to contain CPU and GPU compute
 call

---
 include/ctranslate2/ops/median_filter.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/include/ctranslate2/ops/median_filter.h b/include/ctranslate2/ops/median_filter.h
index ba5e87f0c..b875fa67f 100644
--- a/include/ctranslate2/ops/median_filter.h
+++ b/include/ctranslate2/ops/median_filter.h
@@ -1,5 +1,4 @@
 #pragma once
-
 #include "op.h"
 
 namespace ctranslate2 {
@@ -7,11 +6,13 @@ namespace ctranslate2 {
 
     class MedianFilter : public Op {
     public:
-      MedianFilter(const dim_t width);
+      explicit MedianFilter(dim_t width);
 
       void operator()(const StorageView& input, StorageView& output) const;
 
     private:
       const dim_t _width;
+      template <Device D, typename T>
+      void compute(const StorageView& input, const dim_t axis_size, StorageView& output) const;
     };
 
   }

From 9cd7f17f57d8a9b8c9c2395959332a961794988b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?=
Date: Thu, 25 Sep 2025 14:08:30 +0800
Subject: [PATCH 4/9] Add CPU and GPU of median_filter operator

---
 CMakeLists.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 62b99d136..0bf0fccca 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -170,6 +170,8 @@ set(SOURCES
   src/ops/mean.cc
   src/ops/mean_cpu.cc
   src/ops/median_filter.cc
+  src/ops/median_filter_cpu.cc
+  src/ops/median_filter_gpu.cu
   src/ops/min_max.cc
   src/ops/mul.cc
   src/ops/multinomial.cc

From 08e7dc07c6047cd336e287102f6cc1b05af8afd8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=A4=E3=83=B3=E3=83=A4=E3=83=B3?=
Date: Thu, 25 Sep 2025 14:15:47 +0800
Subject: [PATCH 5/9] Update median_filter.cc

---
 src/ops/median_filter.cc | 48 +++++++---------------------------------
 1 file changed, 8 insertions(+), 40 deletions(-)

diff --git a/src/ops/median_filter.cc b/src/ops/median_filter.cc
index d83c06e11..ab022faf8 100644
--- a/src/ops/median_filter.cc
+++ b/src/ops/median_filter.cc
@@ -1,57 +1,25 @@
 #include "ctranslate2/ops/median_filter.h"
 
-#include <algorithm>
-
-#include "cpu/parallel.h"
+#include "dispatch.h"
 
 namespace ctranslate2 {
   namespace ops {
 
-    MedianFilter::MedianFilter(const dim_t width)
+    MedianFilter::MedianFilter(dim_t width)
       : _width(width)
-    {
-    }
+    {
+    }
 
     void MedianFilter::operator()(const StorageView& input, StorageView& output) const {
       PROFILE("MedianFilter");
 
-      if (input.device() != Device::CPU)
-        throw std::invalid_argument("MedianFilter currently only supports CPU execution");
+      const dim_t axis = input.rank() - 1;
+      const dim_t axis_size = input.dim(axis);
 
       output.resize_as(input);
 
-      const dim_t depth = input.dim(-1);
-      const dim_t batch_size = input.size() / depth;
-      const dim_t rank = _width / 2;
-
-      if (depth <= rank)
-        return;
-
-      const auto* src = input.data<float>();
-      auto* dst = output.data<float>();
-
-      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
-        StorageView window_storage({_width}, DataType::FLOAT32);
-        auto* window = window_storage.data<float>();
-
-        for (dim_t i = begin; i < end; ++i) {
-          const dim_t offset = i * depth;
-          const auto* in = src + offset;
-          auto* out = dst + offset;
-
-          for (dim_t j = 0; j < depth; ++j) {
-            for (dim_t k = -rank; k <= rank; ++k) {
-              dim_t read = std::abs(j + k);
-              if (read >= depth)
-                read = depth - (read - depth) - 2;
-              window[k + rank] = in[read];
-            }
-
-            std::nth_element(window, window + rank, window + _width);
-            out[j] = window[rank];
-          }
-        }
-      });
+      DEVICE_AND_FLOAT_DISPATCH("MedianFilter", input.device(), input.dtype(),
+                                (compute<D, T>(input, axis_size, output)));
     }
 
   }

From 62baa1055e7fe6c059ea832fd5c9477cf89112be Mon Sep 17 00:00:00 2001
From: Jordi Mas
Date: Wed, 17 Dec 2025 13:49:05 +0100
Subject: [PATCH 6/9] Add performance benchmark test

---
 tests/benchmark_ops.cc | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/benchmark_ops.cc b/tests/benchmark_ops.cc
index df3f5d1ba..7d5a2a326 100644
--- a/tests/benchmark_ops.cc
+++ b/tests/benchmark_ops.cc
@@ -116,6 +116,15 @@ void benchmark_conv1d(Device device) {
   BENCHMARK(conv_op(x, weight, bias, y), 100);
 }
 
+void benchmark_median_filter(Device device) {
+  const dim_t width = 5;
+  std::vector<float> x_ = rand_vector(100 * 512);
+  StorageView x({100, 512}, x_, device);
+  StorageView y(device);
+  const ops::MedianFilter median_filter_op(width);
+  BENCHMARK(median_filter_op(x, y), 10000);
+}
+
 int main(int argc, char* argv[]) {
   if (argc < 3) {
     std::cerr << "usage: " << argv[0] << " op device [dtype]" << std::endl;
@@ -153,6 +162,8 @@ int main(int argc, char* argv[]) {
     benchmark_dequantize(device);
   else if (op == "conv1d")
     benchmark_conv1d(device);
+  else if (op == "median_filter")
+    benchmark_median_filter(device);
 
   return 0;
 }

From c937f949b1f1d6979bd6ac290c7e2378c3a46de6 Mon Sep 17 00:00:00 2001
From: Jordi Mas
Date: Wed, 17 Dec 2025 15:01:47 +0100
Subject: [PATCH 7/9] Run the median filter tests also in CUDA device

---
 tests/ops_test.cc | 25 ++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index c9369fa67..e4843b725 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -118,29 +118,32 @@ TEST(OpTest, QuantizeINT16) {
   expect_storage_eq(reverse, input);
 }
 
-TEST(OpTest, MedianFilter) {
+class OpDeviceTest : public ::testing::TestWithParam<Device> {
+};
+
+class OpDeviceFPTest : public ::testing::TestWithParam<FloatType> {
+};
+
+
+TEST_P(OpDeviceTest, MedianFilter) {
+  Device device = GetParam();
   StorageView x({2, 8}, std::vector<float>{
       0.2556743323802948, 0.8028775453567505, 0.3514494299888611, 0.3542254865169525,
       0.5881291031837463, 0.1458204835653305, 0.6845740675926208, 0.543143630027771,
       0.9039326310157776, 0.38000917434692383, 0.9094009399414062, 0.4063926637172699,
-      0.7943458557128906, 0.289182186126709, 0.9932224750518799, 0.01137143187224865});
+      0.7943458557128906, 0.289182186126709, 0.9932224750518799, 0.01137143187224865},
+    device);
   StorageView expected({2, 8}, std::vector<float>{
       0.3514494299888611, 0.3542254865169525, 0.3542254865169525, 0.3542254865169525,
       0.3542254865169525, 0.543143630027771, 0.5881291031837463, 0.543143630027771,
       0.9039326310157776, 0.4063926637172699, 0.7943458557128906, 0.4063926637172699,
-      0.7943458557128906, 0.4063926637172699, 0.7943458557128906, 0.289182186126709});
+      0.7943458557128906, 0.4063926637172699, 0.7943458557128906, 0.289182186126709},
+    device);
-  StorageView y;
+  StorageView y(device);
   ops::MedianFilter(5)(x, y);
   expect_storage_eq(y, expected);
 }
 
-class OpDeviceTest : public ::testing::TestWithParam<Device> {
-};
-
-class OpDeviceFPTest : public ::testing::TestWithParam<FloatType> {
-};
-
-
 TEST_P(OpDeviceTest, Add) {
   Device device = GetParam();
   StorageView a({4}, std::vector<float>{1, 2, 3, 4}, device);

From bccfd4301dea097fcbaba6fdd49c1ac61fdba603 Mon Sep 17 00:00:00 2001
From: a2d8a4v
Date: Fri, 2 Jan 2026 15:44:43 +0800
Subject: [PATCH 8/9] Add a new model: WavLM

---
 CMakeLists.txt                                |   2 +
 include/ctranslate2/layers/attention.h        |   3 +
 include/ctranslate2/layers/wavlm.h            | 112 +++++++++++++++
 include/ctranslate2/models/wavlm.h            |  68 +++++++++
 include/ctranslate2/storage_view.h            |   1 +
 python/cpp/module.cc                          |   1 +
 python/cpp/module.h                           |   1 +
 python/cpp/wavlm.cc                           | 123 +++++++++++++++++
 python/ctranslate2/converters/transformers.py | 127 ++++++++++++++++-
 python/ctranslate2/models/__init__.py         |   1 +
 python/ctranslate2/specs/__init__.py          |   1 +
 python/ctranslate2/specs/attention_spec.py    |   5 +
 python/ctranslate2/specs/wavlm_spec.py        |  72 ++++++++++
 python/tests/test_transformers.py             |  81 +++++++++++
 src/layers/attention.cc                       | 100 +++++++++++--
 src/layers/wavlm.cc                           | 130 ++++++++++++++++++
 src/models/model_factory.cc                   |   3 +
 src/models/wavlm.cc                           | 122 ++++++++++++++++
 src/storage_view.cc                           |   5 +
 19 files changed, 947 insertions(+), 11 deletions(-)
 create mode 100644 include/ctranslate2/layers/wavlm.h
 create mode 100644 include/ctranslate2/models/wavlm.h
 create mode 100644 python/cpp/wavlm.cc
 create mode 100644 python/ctranslate2/specs/wavlm_spec.py
 create mode 100644 src/layers/wavlm.cc
 create mode 100644 src/models/wavlm.cc

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c46d77374..4dc1d906b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -129,6 +129,7 @@ set(SOURCES
   src/layers/common.cc
   src/layers/decoder.cc
   src/layers/transformer.cc
+  src/layers/wavlm.cc
   src/layers/wav2vec2.cc
   src/layers/wav2vec2bert.cc
   src/layers/whisper.cc
@@ -139,6 +140,7 @@ set(SOURCES
   src/models/model_reader.cc
   src/models/sequence_to_sequence.cc
   src/models/transformer.cc
+  src/models/wavlm.cc
   src/models/wav2vec2.cc
   src/models/wav2vec2bert.cc
   src/models/whisper.cc

diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index 570de73f1..b1213a6a1 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -61,6 +61,7 @@ namespace ctranslate2 {
       const StorageView* _relative_position_keys;
       const StorageView* _relative_asymmetric_position_keys;
       const StorageView* _relative_position_values;
+      const StorageView* _gru_relative_position_const;
       dim_t _maximum_relative_position;
       dim_t _relative_left_max_position;
       dim_t _relative_right_max_position;
@@ -68,6 +69,8 @@ namespace ctranslate2 {
       const dim_t _cache_time_dim;
       std::unique_ptr<LayerNorm> _q_norm;  // Query normalization
       std::unique_ptr<LayerNorm> _k_norm;  // Key normalization
+    protected:
+      const std::unique_ptr<Dense> _gru_relative_position_linear;
     };
   }
 }

diff --git a/include/ctranslate2/layers/wavlm.h b/include/ctranslate2/layers/wavlm.h
new file mode 100644
index 000000000..c656980d7
--- /dev/null
+++ b/include/ctranslate2/layers/wavlm.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <optional>
+#include "ctranslate2/layers/transformer.h"
+
+namespace ctranslate2 {
+  namespace layers {
+
+    class WavLMLayerNormConvLayer : public Layer {
+    public:
+      WavLMLayerNormConvLayer(const models::Model& model,
+                              const std::string& scope,
+                              dim_t stride,
+                              dim_t padding);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      dim_t _stride;
+      dim_t _padding;
+      const Conv1D _conv;
+      const LayerNorm _output_norm;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class WavLMPosConvLayer : public Layer {
+    public:
+      WavLMPosConvLayer(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& input, StorageView& output) const;
+
+      DataType output_type() const override {
+        return _conv.output_type();
+      }
+
+      dim_t output_size() const override {
+        return _conv.output_size();
+      }
+
+    private:
+      const Conv1D _conv;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+    };
+
+    class WavLMEncoder : public Layer {
+    public:
+      WavLMEncoder(const models::Model& model, const std::string& scope);
+
+      void operator()(const StorageView& features, StorageView& output);
+
+      DataType output_type() const override {
+        if (_lm_head) {
+          return (*_lm_head).output_type();
+        }
+        else {
+          return _output_norm.output_type();
+        }
+      }
+
+      dim_t output_size() const override {
+        if (_lm_head) {
+          return (*_lm_head).output_size();
+        }
+        else {
+          return _output_norm.output_size();
+        }
+      }
+
+      dim_t input_size() const {
+        return 1024;
+      }
+
+      bool is_encoded(const StorageView& features) const {
+        // Input features shape: [batch_size, input_size, input_time]
+        // Encoder output shape: [batch_size, input_time // 2, output_size]
+        //
+        // input_time is variable so we check that dimension 1 is different than its original value.
+
+        return (features.rank() == 3
+                && features.dim(2) == output_size()
+                && features.dim(1) != input_size());
+      }
+
+      const StorageView* _upgraded_model;
+
+    private:
+      const StorageView* _return_logits;
+      std::optional<WavLMLayerNormConvLayer> _feat_layer0;
+      std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers;
+      std::optional<LayerNorm> _fp_norm;
+      std::optional<Dense> _fp_ff;
+      std::optional<WavLMPosConvLayer> _pos_conv_embed;
+      const ops::Transpose _transpose;
+      const ops::GELU _gelu;
+      const dim_t _num_heads;
+      const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
+      const LayerNorm _output_norm;
+      std::optional<Dense> _lm_head;
+    };
+
+  }
+}

diff --git a/include/ctranslate2/models/wavlm.h b/include/ctranslate2/models/wavlm.h
new file mode 100644
index 000000000..0b797a495
--- /dev/null
+++ b/include/ctranslate2/models/wavlm.h
@@ -0,0 +1,68 @@
+#pragma once
+
+//#include "ctranslate2/generation.h"
+#include "ctranslate2/layers/wavlm.h"
+#include "ctranslate2/models/model.h"
+#include "ctranslate2/replica_pool.h"
+
+namespace ctranslate2 {
+  namespace models {
+
+    struct WavLMOptions {
+      // Maximum generation length.
+      size_t max_length = 448;
+
+      // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
+      size_t sampling_topk = 1;
+
+      // Maximum index of the first predicted timestamp.
+      size_t max_initial_timestamp_index = 50;
+
+      // Suppress blank outputs at the beginning of the sampling.
+      bool suppress_blank = true;
+
+      // List of token IDs to suppress.
+      // -1 will suppress a default set of symbols as defined in the model config.json file.
+      std::vector<int> suppress_tokens = {-1};
+    };
+
+
+    class WavLMModel : public Model {
+    public:
+      const Vocabulary& get_vocabulary() const;
+      size_t current_spec_revision() const override;
+      bool is_quantizable(const std::string& variable_name) const override;
+      bool is_linear_weight(const std::string& variable_name) const override;
+      std::unique_ptr<Model> clone() const override;
+
+      bool use_global_int16_scale() const override {
+        return false;
+      }
+
+    protected:
+      void initialize(ModelReader& model_reader) override;
+    private:
+      std::shared_ptr<const Vocabulary> _vocabulary;
+    };
+
+    class WavLMReplica : public ModelReplica {
+    public:
+      static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);
+
+      WavLMReplica(const std::shared_ptr<const WavLMModel>& model);
+      StorageView encode(StorageView features, const bool to_cpu);
+    private:
+      const std::shared_ptr<const WavLMModel> _model;
+      const std::unique_ptr<layers::WavLMEncoder> _encoder;
+
+      StorageView maybe_encode(StorageView features);
+    };
+
+    class WavLM : public ReplicaPool<WavLMReplica> {
+    public:
+      using ReplicaPool::ReplicaPool;
+      std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
+    };
+
+  }
+}

diff --git a/include/ctranslate2/storage_view.h b/include/ctranslate2/storage_view.h
index 8834ef651..36a5b2925 100644
--- a/include/ctranslate2/storage_view.h
+++ b/include/ctranslate2/storage_view.h
@@ -230,6 +230,7 @@ namespace ctranslate2 {
     template <typename T>
     StorageView& fill(T value);
 
     StorageView& zero();
+    StorageView& one();
 
     StorageView& copy_from(const StorageView& other, bool synchronous = false);

diff --git a/python/cpp/module.cc b/python/cpp/module.cc
index 550aea5b2..88a0a20a8 100644
--- a/python/cpp/module.cc
+++ b/python/cpp/module.cc
@@ -86,6 +86,7 @@ PYBIND11_MODULE(_ext, m)
   ctranslate2::python::register_generator(m);
   ctranslate2::python::register_encoder(m);
   ctranslate2::python::register_whisper(m);
+  ctranslate2::python::register_wavlm(m);
   ctranslate2::python::register_wav2vec2(m);
   ctranslate2::python::register_wav2vec2bert(m);
   ctranslate2::python::register_mpi(m);

diff --git a/python/cpp/module.h b/python/cpp/module.h
index 71d4b3b29..94e60da0e 100644
--- a/python/cpp/module.h
+++ b/python/cpp/module.h
@@ -17,6 +17,7 @@ namespace ctranslate2 {
     void register_translation_stats(py::module& m);
     void register_translator(py::module& m);
    void register_whisper(py::module& m);
+    void register_wavlm(py::module& m);
     void register_wav2vec2(py::module& m);
     void register_wav2vec2bert(py::module& m);
     void register_mpi(py::module& m);

diff --git a/python/cpp/wavlm.cc b/python/cpp/wavlm.cc
new file mode 100644
index 000000000..d48e5c8c2
--- /dev/null
+++ b/python/cpp/wavlm.cc
@@ -0,0 +1,123 @@
+#include "module.h"
+
+#include <shared_mutex>
+
+#include "replica_pool.h"
+
+#include <optional>
+
+namespace ctranslate2 {
+  namespace python {
+
+    class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> {
+    public:
+      using ReplicaPoolHelper::ReplicaPoolHelper;
+
+      StorageView encode(const StorageView& features, const bool to_cpu) {
+        std::shared_lock lock(_mutex);
+        assert_model_is_ready();
+        return _pool->encode(features, to_cpu).get();
+      }
+    };
+
+    void register_wavlm(py::module& m) {
+      py::class_<WavLMWrapper>(
+          m, "WavLM",
+          R"pbdoc(
+              Implements the WavLM self-supervised speech representation model published by Microsoft.
+          )pbdoc")
+
+      .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
+           py::arg("model_path"),
+           py::arg("device")="cpu",
+           py::kw_only(),
+           py::arg("device_index")=0,
+           py::arg("compute_type")="default",
+           py::arg("inter_threads")=1,
+           py::arg("intra_threads")=0,
+           py::arg("max_queued_batches")=0,
+           py::arg("flash_attention")=false,
+           py::arg("tensor_parallel")=false,
+           py::arg("files")=py::none(),
+           R"pbdoc(
+               Initializes a WavLM model from a converted model.
+
+               Arguments:
+                 model_path: Path to the CTranslate2 model directory.
+                 device: Device to use (possible values are: cpu, cuda, auto).
+                 device_index: Device IDs where to place this model on.
+                 compute_type: Model computation type or a dictionary mapping a device name
+                   to the computation type (possible values are: default, auto, int8, int8_float32,
+                   int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
+                 inter_threads: Number of workers to allow executing multiple batches in parallel.
+                 intra_threads: Number of OpenMP threads per worker (0 to use a default value).
+                 max_queued_batches: Maximum number of batches in the worker queue (-1 for unlimited,
+                   0 for an automatic value). When the queue is full, future requests will block
+                   until a free slot is available.
+                 flash_attention: Run the model with flash attention 2 for the self-attention layers.
+                 tensor_parallel: Run the model in tensor parallel mode.
+                 files: Load model files from the memory. This argument is a dictionary mapping
+                   file names to file contents as file-like or bytes objects. If this is set,
+                   :obj:`model_path` acts as an identifier for this model.
+           )pbdoc")
+
+      .def_property_readonly("device", &WavLMWrapper::device,
+                             "Device this model is running on.")
+      .def_property_readonly("device_index", &WavLMWrapper::device_index,
+                             "List of device IDs where this model is running on.")
+      .def_property_readonly("compute_type", &WavLMWrapper::compute_type,
+                             "Computation type used by the model.")
+      .def_property_readonly("num_workers", &WavLMWrapper::num_replicas,
+                             "Number of model workers backing this instance.")
+      .def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches,
+                             "Number of batches waiting to be processed.")
+      .def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel,
+                             "Whether the model runs in tensor parallel mode.")
+      .def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches,
+                             "Number of batches waiting to be processed or currently processed.")
+
+      .def("encode", &WavLMWrapper::encode,
+           py::arg("features"),
+           py::arg("to_cpu")=false,
+           py::call_guard<py::gil_scoped_release>(),
+           R"pbdoc(
+               Encodes the input features.
+
+               Arguments:
+                 features: Pre-extracted hidden states (as in CTranslate2 up to v4.3.1, see
+                   https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028)
+                   or raw audio (e.g. after VAD), as a float array with shape ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``.
+                 to_cpu: Copy the encoder output to the CPU before returning the value.
+
+               Returns:
+                 The encoder output.
+           )pbdoc")
+
+      .def("unload_model", &WavLMWrapper::unload_model,
+           py::arg("to_cpu")=false,
+           py::call_guard<py::gil_scoped_release>(),
+           R"pbdoc(
+               Unloads the model attached to this WavLM instance but keeps enough runtime context
+               to quickly resume inference on the initial device.
+
+               Arguments:
+                 to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
+           )pbdoc")
+
+      .def("load_model", &WavLMWrapper::load_model,
+           py::arg("keep_cache")=false,
+           py::call_guard<py::gil_scoped_release>(),
+           R"pbdoc(
+               Loads the model back to the initial device.
+
+               Arguments:
+                 keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
+           )pbdoc")
+
+      .def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded,
+                             "Whether the model is loaded on the initial device and ready to be used.")
+      ;
+    }
+
+  }
+}

diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 8cd3ac536..b039f1377 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -22,6 +22,7 @@
     common_spec,
     model_spec,
     transformer_spec,
+    wavlm_spec,
     wav2vec2_spec,
     wav2vec2bert_spec,
     whisper_spec,
@@ -143,9 +144,13 @@ class TransformersConverter(Converter):
         if self._trust_remote_code:
             tokenizer_kwargs["trust_remote_code"] = self._trust_remote_code
 
-        tokenizer = self.load_tokenizer(
-            tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
-        )
+        try:
+            tokenizer = self.load_tokenizer(
+                tokenizer_class, self._model_name_or_path, **tokenizer_kwargs
+            )
+        except Exception:
+            tokenizer = None
+            print("Skipping the tokenizer: this model does not provide one.")
 
         spec = loader(model, tokenizer)
 
@@ -1076,6 +1081,122 @@ def set_common_layers(self, spec, module):
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
 
 
+
+@register_loader("WavLMConfig")
+class WavLMLoader(BartLoader):
+    @property
+    def architecture_name(self):
+        return "WavLMModel"
+
+    def get_model_spec(self, model):
+        return_hidden = getattr(model.config, "return_hidden", False)
+        spec = wavlm_spec.WavLMSpec(
+            model.config.num_feat_extract_layers,
+            model.encoder.config.num_hidden_layers,
+            model.encoder.config.num_attention_heads,
+            return_hidden,
+        )
+
+        # Alias the WavLM submodule names to the ones expected by the shared
+        # layer-setting helpers (avoids saving duplicated weights).
+        for layer in model.encoder.layers:
+            layer.self_attn = layer.attention
+            layer.self_attn_layer_norm = layer.layer_norm
+            layer.activation_fn = layer.feed_forward.intermediate_act_fn
+            layer.fc1 = layer.feed_forward.intermediate_dense
+            layer.fc2 = layer.feed_forward.output_dense
+
+        self.set_encoder(spec.encoder, model, model.config)
+        return spec
+
+    def set_config(self, config, model, tokenizer):
+        config.layer_norm_epsilon = model.config.layer_norm_eps
+
+    def get_vocabulary(self, model, tokenizer):
+        return
+
+    def set_vocabulary(self, spec, tokens):
+        return
+
+    def set_feature_extractor(self, spec, feature_extractor):
+        spec.feat_layer0.conv.weight = feature_extractor.conv_layers[0].conv.weight
+        # spec.feat_layer0.conv.bias = feature_extractor.conv_layers[0].conv.bias  # WavLM convolutions have no bias
+        self.set_layer_norm(
+            spec.feat_layer0.layer_norm, feature_extractor.conv_layers[0].layer_norm
+        )
+        for spec_layer, module_layer in zip(
+            spec.feat_layer, feature_extractor.conv_layers[1:]
+        ):
+            spec_layer.conv.weight = module_layer.conv.weight
+            # spec_layer.conv.bias = module_layer.conv.bias  # WavLM convolutions have no bias
+            self.set_layer_norm(spec_layer.layer_norm, module_layer.layer_norm)
+
+    def set_feature_projection(self, spec, feature_projection):
+        self.set_layer_norm(spec.fp_layer_norm, feature_projection.layer_norm)
+        self.set_linear(spec.fp_projection, feature_projection.projection)
+
+    def set_pos_conv_embed(self, spec, encoder, config):
+        # Force the parameters to be materialized: some transformers versions initialize them with garbage values.
+        # The conv parameters may be float16, so force float32 for the loading.
+        encoder.pos_conv_embed.conv.weight.data = (
+            encoder.pos_conv_embed.conv.weight.data.float()
+        )
+        encoder.pos_conv_embed.conv.bias.data = encoder.pos_conv_embed.conv.bias.float()
+        for param in encoder.pos_conv_embed.parameters():
+            param.data = param.data.float()
+        encoder.pos_conv_embed(torch.randn((1, 1, config.hidden_size)))
+        spec.pos_conv_embed.conv.weight = encoder.pos_conv_embed.conv.weight
+        spec.pos_conv_embed.conv.bias = encoder.pos_conv_embed.conv.bias
+
+    def set_encoder(self, spec, model, config):
+        self.set_feature_extractor(spec, model.feature_extractor)
+        self.set_feature_projection(spec, model.feature_projection)
+        self.set_pos_conv_embed(spec, model.encoder, config)
+        self.set_wavlm_encoder_layer(spec, model.encoder)
+
+    def set_wavlm_encoder_layer(self, spec, encoder):
+        self.set_common_layers(spec, encoder)
+
+        for layer_index, (layer_spec, layer) in enumerate(zip(spec.layer, encoder.layers)):
+            self.set_attention(
+                layer_spec.self_attention,
+                layer.self_attn,
+                self_attention=True,
+                has_rel_attn_embed=(layer_index == 0),
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm,
+                layer.self_attn_layer_norm,
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.fc1)
+            self.set_linear(layer_spec.ffn.linear_1, layer.fc2)
+            self.set_layer_norm(layer_spec.ffn.layer_norm, layer.final_layer_norm)
+
+    def set_attention(self, spec, attention, self_attention=False, has_rel_attn_embed=False):
+        split_layers = [common_spec.LinearSpec() for _ in range(3)]
+        self.set_linear(split_layers[0], attention.q_proj)
+        self.set_linear(split_layers[1], attention.k_proj)
+        self.set_linear(split_layers[2], attention.v_proj)
+
+        if self_attention:
+            utils.fuse_linear(spec.linear[0], split_layers)
+        else:
+            utils.fuse_linear(spec.linear[0], split_layers[:1])
+            utils.fuse_linear(spec.linear[1], split_layers[1:])
+
+        self.set_linear(spec.linear[-1], attention.out_proj)
+
+        self.set_linear(spec.gru_relative_position_linear, attention.gru_rel_pos_linear)
+        spec.gru_relative_position_const = attention.gru_rel_pos_const.data  # a torch.nn.Parameter, so keep only its tensor data
+
+        if has_rel_attn_embed:
+            spec.relative_attention_bias = attention.rel_attn_embed.weight
+            spec.relative_attention_max_distance = np.int32(attention.max_distance)
+
+    def set_common_layers(self, spec, module):
+        self.set_layer_norm(spec.layer_norm, module.layer_norm)
+
+
 @register_loader("Wav2Vec2BertConfig")
 class Wav2Vec2BertLoader(BartLoader):
     @property

diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py
index 35a3dca37..43855f0a4 100644
--- a/python/ctranslate2/models/__init__.py
+++ b/python/ctranslate2/models/__init__.py
@@ -5,6 +5,7 @@
 try:
     from ctranslate2._ext import (
         Wav2Vec2,
+        WavLM,
         Wav2Vec2Bert,
         Whisper,
         WhisperGenerationResult,

diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py
index b4e53fad2..553072f15 100644
--- a/python/ctranslate2/specs/__init__.py
+++ b/python/ctranslate2/specs/__init__.py
@@ -13,6 +13,7 @@
     TransformerEncoderSpec,
     TransformerSpec,
 )
+from ctranslate2.specs.wavlm_spec import WavLMSpec
 from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
 from ctranslate2.specs.wav2vec2bert_spec import Wav2Vec2BertSpec
 from ctranslate2.specs.whisper_spec import WhisperSpec

diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index 97a33b1c2..5198b27ab 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -21,6 +21,7 @@ def __init__(
         relative_position=False,
         relative_asymmetric_position=False,
         relative_attention_bias=False,
+        gated_relative_attention_bias=False,
         rms_norm=False,
         rotary_dim=None,
         rotary_interleave=True,
@@ -59,6 +60,10 @@ def __init__(
         self.relative_left_max_position = None
         self.relative_right_max_position = None
 
+        if gated_relative_attention_bias:
+            self.gru_relative_position_const = None
+            self.gru_relative_position_linear = common_spec.LinearSpec()
+
         if original_max_position_embeddings != 0:
             self.original_max_position_embeddings = np.dtype("int32").type(
                 original_max_position_embeddings

diff --git a/python/ctranslate2/specs/wavlm_spec.py b/python/ctranslate2/specs/wavlm_spec.py
new file mode 100644
index 000000000..b9aa00be2
--- /dev/null
+++ b/python/ctranslate2/specs/wavlm_spec.py
@@ -0,0 +1,72 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from ctranslate2.specs import common_spec, model_spec, transformer_spec
+
+
+class WavLMConfig(model_spec.ModelConfig):
+    """Configuration for the WavLM model."""
+
+    def __init__(self, layer_norm_epsilon: float = None, **kwargs):
+        super().__init__(layer_norm_epsilon=layer_norm_epsilon, **kwargs)
+
+
+class WavLMSpec(model_spec.LanguageModelSpec):
+    def __init__(
+        self,
+        feat_layers,
+        num_layers,
+        num_heads,
+        return_hidden,
+    ):
+        super().__init__()
+        self.vocab_size = np.dtype("int16").type(0)
+        self.encoder = WavLMEncoderSpec(
+            feat_layers,
+            num_layers,
+            num_heads,
+            return_hidden,
+        )
+
+    @property
+    def name(self):
+        return "WavLMSpec"
+
+    @property
+    def revision(self):
+        return 3
+
+    def get_default_config(self):
+        return WavLMConfig()
+
+    def get_vocabulary_size(self):
+        return 0
+
+
+class WavLMLayerNormConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+        self.layer_norm = common_spec.LayerNormSpec()
+
+
+class WavLMPosEmbedConvLayer(model_spec.LayerSpec):
+    def __init__(self):
+        self.conv = common_spec.Conv1DSpec()
+
+
+class WavLMEncoderSpec(model_spec.LayerSpec):
+    def __init__(self, feat_layers, num_layers, num_heads, return_hidden):
+        self.num_heads = np.dtype("int16").type(num_heads)
+        self.feat_layer0 = WavLMLayerNormConvLayer()
+        self.feat_layer = [WavLMLayerNormConvLayer() for _ in range(feat_layers - 1)]
+        self.fp_layer_norm = common_spec.LayerNormSpec()
+        self.fp_projection = common_spec.LinearSpec()
+        self.pos_conv_embed = WavLMPosEmbedConvLayer()
+        self.layer_norm = common_spec.LayerNormSpec()
+        self.layer = [
+            transformer_spec.TransformerEncoderLayerSpec(gated_relative_attention_bias=True,
+                                                         relative_attention_bias=(i == 0)) for i in range(num_layers)
+        ]
+        # if not return_hidden:
+        #     self.lm_head = common_spec.LinearSpec()

diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index af128ac70..bfc8747f5 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -1040,6 +1040,87 @@ def test_transformers_wav2vec2(
         assert transcription == expected_transcription[0]
 
 
+class TestWavLM:
+    @classmethod
+    def teardown_class(cls):
+        clear_transformers_cache_in_ci()
+
+    @test_utils.only_on_linux
+    @test_utils.on_available_devices
+    @pytest.mark.parametrize(
+        "model_name,expected_transcription",
+        [
+            (
+                "microsoft/wavlm-large",
+                [
+                    "MISTER QUILTER IS THE APOSSEL OF THE MIDDLE CLASSES AND"
+                    " WE ARE GLAD TO WELCOME HIS GOSPEL",
+                ],
+            ),
+        ],
+    )
+    def test_transformers_wavlm(
+        self,
+        tmp_dir,
+        device,
+        model_name,
+        expected_transcription,
+    ):
+        import torch
+        import transformers
+
+        converter = ctranslate2.converters.TransformersConverter(
+            model_name, load_as_float16="int8"
+        )
+        output_dir = str(tmp_dir.join("ctranslate2_model"))
+        output_dir = converter.convert(output_dir)
+
+        wavlm_processor = transformers.Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+        wavlm_processor.save_pretrained(output_dir + "/wavlm_processor")
+        processor = transformers.AutoFeatureExtractor.from_pretrained(
+            output_dir + "/wavlm_processor"
+        )
+
+        device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu"
+        cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0))
+        model = ctranslate2.models.WavLM(
+            output_dir,
+            device=device,
+            device_index=[0],
+            compute_type="int8",
+            intra_threads=cpu_threads,
+            inter_threads=1,
+        )
+        hf_model = transformers.WavLMModel.from_pretrained(model_name)
+
+        speech_array = np.load(
+            os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy")
+        )
+        input_values = processor(
+            speech_array,
+            padding=True,
+            return_tensors="pt",
+            sampling_rate=16000,
+        ).input_values
+
+        hidden_states = np.ascontiguousarray(input_values.unsqueeze(0))
+        hidden_states = ctranslate2.StorageView.from_array(hidden_states)
+        to_cpu = model.device == "cuda" and len(model.device_index) > 1
+        output = model.encode(hidden_states, to_cpu=to_cpu)
+        if model.device == "cuda":
+            last_hidden_state = torch.as_tensor(output, device=model.device)[0]
+        else:
+            last_hidden_state = torch.as_tensor(
+                np.array(output), dtype=torch.float32, device=model.device
+            )[0]
+
+        hf_output = hf_model(input_values)
+
+        similarity = torch.nn.functional.cosine_similarity(last_hidden_state.flatten(), hf_output.last_hidden_state.flatten(), dim=0)
+
+        assert similarity > 0.99  # int8 quantization prevents exact equality
+
+
 class TestWav2Vec2Bert:
     @classmethod
     def teardown_class(cls):

diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 9afd773be..40df896d1 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -2,7 +2,6 @@
 
 #include "ctranslate2/ops/split.h"
 #include "ctranslate2/utils.h"
-
 #include <algorithm>
 #include <cmath>
 #include <numeric>
@@ -181,10 +180,13 @@ namespace ctranslate2 {
                                         const StorageView& keys,
                                         const StorageView& values,
                                         const StorageView* values_lengths,
+                                        const StorageView* q,
                                         const StorageView* relative_position_keys,
                                         const StorageView* relative_asymmetric_position_keys,
                                         const StorageView* relative_position_values,
                                         const StorageView* relative_attention_bias,
+                                        const Dense* gru_relative_position_linear,
+                                        const StorageView* gru_relative_position_const,
                                         dim_t relative_left_max_position,
                                         dim_t relative_right_max_position,
                                         dim_t maximum_relative_position,
@@ -230,13 +232,13 @@ namespace ctranslate2 {
                          *relative_asymmetric_position_keys,
                          keys_matmul,
                          output);
 
-      if (relative_attention_bias) {
+      if (relative_attention_bias || (gru_relative_position_linear && gru_relative_position_const)) {
         StorageView local_position_bias(output.dtype(), output.device());
 
         if (!position_bias)
           position_bias = &local_position_bias;
 
-        if (position_bias->empty()) {
+        if (position_bias->empty() && relative_attention_bias) {
           const dim_t query_length = queries.dim(2);
           const dim_t key_length = keys.dim(2);
           *position_bias = compute_relative_bias(*relative_attention_bias,
@@ -256,11 +258,80 @@ namespace ctranslate2 {
           position_bias_per_gpu = &position_bias_tmp;
         }
 
-        DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
-                                 primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
-                                                                    output.data<T>(),
-                                                                    position_bias_per_gpu->size(),
-                                                                    output.size()));
+        if (gru_relative_position_linear && gru_relative_position_const) {
+
+          // WavLM gated relative position bias.
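+          // Matches WavLMAttention in HuggingFace Transformers:
+          //   gate_a, gate_b = sigmoid(gru_rel_pos_linear(hidden)).chunk(2)
+          //   gated_bias = (gate_a * (gate_b * gru_rel_pos_const - 1.0) + 2.0) * rel_pos_bias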
+          // Use dimensions from queries (projected Q) which has shape [B,H,T,dh].
+          const dim_t bsz = queries.dim(0);
+          const dim_t num_heads = queries.dim(1);
+          const dim_t seq_len = queries.dim(2);
+          const dim_t head_dim = queries.dim(3);
+
+          // Reshape raw hidden states `q` [B,T,D] -> [B,T,H,dh] -> transpose to [B,H,T,dh].
+          // This matches HuggingFace WavLM which uses raw hidden_states (not projected Q).
+          StorageView q_reshaped(output.dtype(), output.device());
+          q_reshaped.shallow_copy(const_cast<StorageView&>(*q));
+          q_reshaped.reshape({bsz, seq_len, num_heads, head_dim});
+
+          StorageView q_view(output.dtype(), output.device());
+          ops::Transpose({0, 2, 1, 3})(q_reshaped, q_view);  // [B,T,H,dh] -> [B,H,T,dh]
+          q_reshaped.release();
+
+          // Project to 8 dims per head without in-place aliasing.
+          StorageView relative_position_proj(output.dtype(), output.device());
+          (*gru_relative_position_linear)(q_view, relative_position_proj);  // [B,H,T,8]
+          q_view.release();
+
+          relative_position_proj.reshape({bsz, num_heads, seq_len, 2, 4});
+          ops::Sum(-1)(relative_position_proj, relative_position_proj);  // [B,H,T,2,1]
+          relative_position_proj.squeeze(-1);  // [B,H,T,2]
+          ops::Sigmoid()(relative_position_proj, relative_position_proj);
+
+          StorageView gate_a(output.dtype(), output.device());
+          StorageView gate_b(output.dtype(), output.device());
+          std::vector<StorageView*> gates{&gate_a, &gate_b};
+          ops::Split(-1, {1, 1})(relative_position_proj, gates);  // split in 2 along the last dimension
+          relative_position_proj.release();
+
+          StorageView gru_relative_position_const_tiled(gate_b.shape(), output.dtype(), output.device());
+          ops::Tile(2, seq_len)(*gru_relative_position_const, gru_relative_position_const_tiled);
+          ops::Mul()(gate_b, gru_relative_position_const_tiled, gate_b);
+          gru_relative_position_const_tiled.release();
+          StorageView one_const(gate_b.shape(), output.dtype(), output.device());
+          one_const.one();
+          ops::Sub()(gate_b, one_const, gate_b);
+
+          ops::Mul()(gate_a, gate_b, gate_a);
+
+          StorageView two_const(gate_b.shape(), output.dtype(), output.device());
+          two_const.fill(2.0f);
+          ops::Add()(gate_a, two_const, gate_a);
+          gate_a.reshape({num_heads, seq_len, 1});  // e.g. torch.Size([16, 128, 1]); assumes batch_size == 1
+          gate_b.release();
+
+          StorageView gated_position_bias(output.dtype(), output.device());
+          ops::Tile(2, seq_len)(gate_a, gated_position_bias);  // [16, 128, 128]
+          gate_a.release();
+
+          ops::Mul()(gated_position_bias, *position_bias_per_gpu, gated_position_bias);  // [16, 128, 128]
+          gated_position_bias.expand_dims(0);  // [1, 16, 128, 128]
+
+          DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
+                                   primitives<D>::add_batch_broadcast(gated_position_bias.data<T>(),
+                                                                      output.data<T>(),
+                                                                      gated_position_bias.size(),
+                                                                      output.size()));
+
+        } else {
+          DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
+                                   primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
+                                                                      output.data<T>(),
+                                                                      position_bias_per_gpu->size(),
+                                                                      output.size()));
+        }
       }
 
       if (alibi)
@@ -305,6 +376,8 @@ namespace ctranslate2 {
       , _relative_position_keys(model.get_variable_if_exists(scope + "/relative_position_keys"))
       , _relative_asymmetric_position_keys(model.get_variable_if_exists(scope + "/relative_asymmetric_position_keys"))
       , _relative_position_values(model.get_variable_if_exists(scope + "/relative_position_values"))
+      , _gru_relative_position_const(model.get_variable_if_exists(scope + "/gru_relative_position_const"))
+      , _gru_relative_position_linear(build_optional_layer<Dense>(model, scope + "/gru_relative_position_linear"))
       , _merge_time_and_head_dims(_multi_query &&
                                   !_relative_attention_bias &&
                                   !_relative_position_keys
@@ -356,12 +429,20 @@ namespace ctranslate2 {
 
       StorageView keys_proj(dtype, device);
       StorageView values_proj(dtype, device);
+      StorageView layer_normed_hidden(dtype, device);
 
       const StorageView* q = &queries;
 
       if (_layer_norm && _pre_norm) {
         (*_layer_norm)(queries, queries_proj);
         q = &queries_proj;
+        if (_gru_relative_position_linear) {
+          layer_normed_hidden = queries_proj;  // deep copy
+        }
       }
 
+      // Point q to the saved copy if we have gru_relative_position_linear.
+      const StorageView* q_for_rel_pos = (_gru_relative_position_linear && !layer_normed_hidden.empty())
+                                             ? &layer_normed_hidden : q;
+
       _linear[0](*q, fused_proj);
 
       dim_t beam_size = 1;
@@ -524,10 +605,13 @@ namespace ctranslate2 {
             keys_proj,
             values_proj,
             values_lengths,
+            q_for_rel_pos,  // raw hidden states (after layer norm) for the gated relative position bias
             _relative_position_keys,
             _relative_asymmetric_position_keys,
             _relative_position_values,
             _relative_attention_bias,
+            _gru_relative_position_linear.get(),
+            _gru_relative_position_const,
             _relative_left_max_position,
             _relative_right_max_position,
             _maximum_relative_position,

diff --git a/src/layers/wavlm.cc b/src/layers/wavlm.cc
new file mode 100644
index 000000000..b827736d2
--- /dev/null
+++ b/src/layers/wavlm.cc
@@ -0,0 +1,130 @@
+#include "ctranslate2/layers/wavlm.h"
+
+namespace ctranslate2 {
+  namespace layers {
+
+    WavLMLayerNormConvLayer::WavLMLayerNormConvLayer(const models::Model& model,
+                                                     const std::string& scope,
+                                                     dim_t stride,
+                                                     dim_t padding)
+      : _stride(stride)
+      , _padding(padding)
+      , _conv(model, scope + "/conv", _stride, _padding)
+      , _output_norm(model, scope + "/layer_norm")
+      , _transpose({0, 2, 1}) {
+    }
+
+    void WavLMLayerNormConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("WavLMLayerNormConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      buffer = input;
+      _conv(buffer, output);
+      _transpose(output, buffer);
+      _output_norm(buffer, output);
+      _transpose(output, buffer);
+      _gelu(buffer, output);
+    }
+
+    WavLMPosConvLayer::WavLMPosConvLayer(const models::Model& model, const std::string& scope)
+      : _conv(model, scope + "/conv", /*stride=*/1, /*padding=*/64, /*dilation=*/1, /*groups=*/16)
+      , _transpose({0, 2, 1}) {
+    }
+
+    void WavLMPosConvLayer::operator()(const StorageView& input, StorageView& output) const {
+      PROFILE("WavLMPosConvLayer");
+
+      StorageView buffer(input.dtype(), input.device());
+      StorageView buffer2(input.dtype(), input.device());
+      _transpose(input, buffer);
+      _conv(buffer, buffer2);
+      ops::Split(2, {buffer.dim(2), 1})(buffer2, buffer, output);
+      _gelu(buffer, buffer);
+      _transpose(buffer, buffer2);
+      ops::Add()(input, buffer2, output);
+    }
+
+    WavLMEncoder::WavLMEncoder(const models::Model& model, const std::string& scope)
+      : _upgraded_model(model.get_variable_if_exists(scope + "/fp_projection/weight"))
+      , _return_logits(model.get_variable_if_exists(scope + "/lm_head/weight"))
+      , _transpose({0, 2, 1})
+      , _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
+      , _layers(build_layers_list<const TransformerEncoderLayer>(model,
+                                                                 scope + "/layer",
+                                                                 _num_heads,
+                                                                 /*pre_norm=*/true,
+                                                                 ops::ActivationType::GELU))
+      , _output_norm(model, scope + "/layer_norm")
+    {
+      if (_upgraded_model) {
+        _feat_layer0.emplace(model, scope + "/feat_layer0", /*stride=*/5, /*padding=*/0);
+        _feat_layers.emplace(build_layers_list<const WavLMLayerNormConvLayer>(model,
+                                                                              scope + "/feat_layer",
+                                                                              /*stride=*/2,
+                                                                              /*padding=*/0));
+        _fp_norm.emplace(model, scope + "/fp_layer_norm");
+        _fp_ff.emplace(model, scope + "/fp_projection", nullptr, true);
+        _pos_conv_embed.emplace(model, scope + "/pos_conv_embed");
+        if (_return_logits) {
+          _lm_head.emplace(model, scope + "/lm_head", nullptr, true);
+        }
+      }
+    }
+
+    void WavLMEncoder::operator()(const StorageView& features, StorageView& output) {
+      PROFILE("WavLMEncoder");
+
+      // SAD in the front-end handles the input length.
+      if (features.rank() != 3)
+        throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
+                                    + std::to_string(features.rank())
+                                    + " dimension(s) instead");
+      if (_upgraded_model) {
+        // WavLMFeatureExtractor ------------------------------------
+        StorageView feat_buffer(features.dtype(), features.device());
+        StorageView feat_buffer2(features.dtype(), features.device());
+        feat_buffer = features;
+        (*_feat_layer0)(feat_buffer, output);
+        feat_buffer = std::move(output);
+        for (size_t l = 0; l < _feat_layers->size(); l++) {
+          (*_feat_layers.value()[l])(feat_buffer, output);
+          if (l < _feat_layers->size() - 1)
+            feat_buffer = std::move(output);
+        }
+        _transpose(output, feat_buffer);
+        // WavLMFeatureProjection -----------------------------------
+        (*_fp_norm)(feat_buffer, output);
+        (*_fp_ff)(output, feat_buffer);
+        // WavLMEncoderStableLayerNorm
+        // WavLMPositionalConvEmbedding -----------------------------
+        (*_pos_conv_embed)(feat_buffer, feat_buffer2);
+        // WavLMEncoderLayerStableLayerNorm -------------------------
+        StorageView position_bias(features.dtype(), features.device());
+        for (const auto& layer : _layers) {
+          (*layer)(feat_buffer2, nullptr, feat_buffer, nullptr, &position_bias);
+          feat_buffer2 = std::move(feat_buffer);
+        }
+        if (_return_logits) {
+          _output_norm(feat_buffer2, feat_buffer);  // default is True
+          (*_lm_head)(feat_buffer, output);
+        }
+        else {
+          _output_norm(feat_buffer2, output);
+        }
+      }
+      else {  // backward compatibility for the previously converted model
+        StorageView input(features.dtype(), features.device());
+        input = features;
+        StorageView position_bias(features.dtype(), features.device());
+        for (const auto& layer : _layers) {
+          (*layer)(input, nullptr, output, nullptr, &position_bias);
+          input = std::move(output);
+        }
+
+        _output_norm(input, output);
+      }
+    }
+
+  }
+}

diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc
index 059051f5d..8d403fa64 100644
--- a/src/models/model_factory.cc
+++ b/src/models/model_factory.cc
@@ -3,6 +3,7 @@
 #include <unordered_map>
 
 #include "ctranslate2/models/whisper.h"
+#include "ctranslate2/models/wavlm.h"
 #include "ctranslate2/models/wav2vec2.h"
 #include "ctranslate2/models/wav2vec2bert.h"
 #include "ctranslate2/models/transformer.h"
@@ -23,6 +24,8 @@ namespace ctranslate2 {
 
       register_model<WhisperModel>("WhisperSpec");
 
+      register_model<WavLMModel>("WavLMSpec");
+
       register_model<Wav2Vec2Model>("Wav2Vec2Spec");
 
       register_model<Wav2Vec2BertModel>("Wav2Vec2BertSpec");

diff --git a/src/models/wavlm.cc b/src/models/wavlm.cc
new file mode 100644
index 000000000..c75f7c4c4
--- /dev/null
+++ b/src/models/wavlm.cc
@@ -0,0 +1,122 @@
+#include "ctranslate2/models/wavlm.h"
+
+#include <algorithm>
+
+#include "ctranslate2/decoding.h"
+
+#include "dispatch.h"
+#include "dtw.h"
+
+#ifdef CT2_WITH_CUDA
+#  include "cuda/utils.h"
+#endif
+
+
+namespace ctranslate2 {
+  namespace models {
+
+    const Vocabulary& WavLMModel::get_vocabulary() const {
+      return *_vocabulary;
+    }
+
+    size_t WavLMModel::current_spec_revision() const {
+      return 3;
+    }
+
+    void WavLMModel::initialize(ModelReader& model_reader) {
+      VocabularyInfo vocab_info;
+      vocab_info.unk_token = "[UNK]";
+      vocab_info.bos_token = "<s>";
+      vocab_info.eos_token = "</s>";
+
+      _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
+      if (!_vocabulary)
+        throw std::runtime_error("Cannot load the vocabulary from the model directory");
+    }
+
+    bool WavLMModel::is_quantizable(const std::string& variable_name) const {
+      return Model::is_quantizable(variable_name);
+    }
+
+    bool WavLMModel::is_linear_weight(const std::string& variable_name) const {
+      return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos;
+    }
+
+    std::unique_ptr<Model> WavLMModel::clone() const {
+      return std::make_unique<WavLMModel>(*this);
+    }
+
+
+    std::unique_ptr<WavLMReplica> WavLMReplica::create_from_model(const Model& model) {
+      if (!dynamic_cast<const WavLMModel*>(&model))
+        throw std::invalid_argument("The model is not a WavLM model");
+
+      const auto scoped_device_setter = model.get_scoped_device_setter();
+      const auto model_ptr = model.shared_from_this();
+      const auto concrete_model = std::static_pointer_cast<const WavLMModel>(model_ptr);
+      return std::make_unique<WavLMReplica>(concrete_model);
+    }
+
+    WavLMReplica::WavLMReplica(const std::shared_ptr<const WavLMModel>& model)
+      : ModelReplica(model)
+      , _model(model)
+      , _encoder(std::make_unique<layers::WavLMEncoder>(*model, "encoder"))
+    {
+    }
+
+    StorageView WavLMReplica::encode(StorageView features, const bool to_cpu) {
+      PROFILE("WavLMReplica::encode");
+
+#ifdef CT2_WITH_CUDA
+      const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false);
+#endif
+
+      const auto scoped_device_setter = _model->get_scoped_device_setter();
+      const Device device = _model->device();
+      const DataType dtype = _encoder->output_type();
+      features.move_to(device, dtype);
+
+      StorageView encoder_output(dtype, device);
+      if (_encoder->_upgraded_model) {
+        encoder_output = maybe_encode(std::move(features));
+      }
+      else {
+        (*_encoder)(features, encoder_output);
+      }
+
+      if (to_cpu) {
+        if (device != Device::CPU)
+          encoder_output = encoder_output.to(Device::CPU);
+
+        return encoder_output;
+      }
+
+      // Ensure all operations are finished before returning the output.
+      synchronize_stream(device);
+
+      return encoder_output;
+    }
+
+    StorageView WavLMReplica::maybe_encode(StorageView features) {
+      const Device device = _model->device();
+      const DataType dtype = _encoder->output_type();
+
+      features.move_to(device, dtype);
+
+      if (_encoder->is_encoded(features))
+        return features;
+
+      StorageView encoder_output(dtype, device);
+      (*_encoder)(features, encoder_output);
+      return encoder_output;
+    }
+
+    std::future<StorageView> WavLM::encode(const StorageView& features, const bool to_cpu) {
+      return post<StorageView>(
+          [features = features.sync_copy(), to_cpu](WavLMReplica& replica) mutable {
+            return replica.encode(std::move(features), to_cpu);
+          });
+    }
+
+  }
+}

diff --git a/src/storage_view.cc b/src/storage_view.cc
index 0cbdf25c2..d91efb275 100644
--- a/src/storage_view.cc
+++ b/src/storage_view.cc
@@ -405,6 +405,11 @@ namespace ctranslate2 {
     return *this;
   }
 
+  StorageView& StorageView::one() {
+    DEVICE_AND_TYPE_DISPATCH(_device, _dtype, primitives<D>::fill(data<T>(), T(1), _size));
+    return *this;
+  }
+
   template <typename T>
   StorageView& StorageView::copy_from(const T* data, dim_t size, Device device, bool synchronous) {
     if (size != _size)

From d4a54519a9a688e6be424bdf4cf436ac97f97b71 Mon Sep 17 00:00:00 2001
From: a2d8a4v
Date: Tue, 13 Jan 2026 14:32:16 +0800
Subject: [PATCH 9/9] Remove unnecessary check from WavLM (refs #1977)

---
 include/ctranslate2/layers/wavlm.h | 11 -----------
 include/ctranslate2/models/wavlm.h |  2 --
 src/models/wavlm.cc                | 21 +--------------------
 3 files changed, 1 insertion(+), 33 deletions(-)

diff --git a/include/ctranslate2/layers/wavlm.h b/include/ctranslate2/layers/wavlm.h
index c656980d7..74894601b 100644
--- a/include/ctranslate2/layers/wavlm.h
+++ b/include/ctranslate2/layers/wavlm.h
@@ -80,17 +80,6 @@ namespace ctranslate2 {
         return 1024;
       }
 
-      bool is_encoded(const StorageView& features) const {
-        // Input features shape: [batch_size, input_size, input_time]
-        // Encoder output shape: [batch_size, input_time // 2, output_size]
-        //
-        // input_time is variable so we check that dimension 1 is different than its original value.
-
-        return (features.rank() == 3
-                && features.dim(2) == output_size()
-                && features.dim(1) != input_size());
-      }
-
       const StorageView* _upgraded_model;
 
     private:

diff --git a/include/ctranslate2/models/wavlm.h b/include/ctranslate2/models/wavlm.h
index 0b797a495..bc6f0b555 100644
--- a/include/ctranslate2/models/wavlm.h
+++ b/include/ctranslate2/models/wavlm.h
@@ -54,8 +54,6 @@ namespace ctranslate2 {
     private:
       const std::shared_ptr<const WavLMModel> _model;
       const std::unique_ptr<layers::WavLMEncoder> _encoder;
-
-      StorageView maybe_encode(StorageView features);
     };
 
     class WavLM : public ReplicaPool<WavLMReplica> {

diff --git a/src/models/wavlm.cc b/src/models/wavlm.cc
index c75f7c4c4..9437130da 100644
--- a/src/models/wavlm.cc
+++ b/src/models/wavlm.cc
@@ -77,12 +77,7 @@ namespace ctranslate2 {
       features.move_to(device, dtype);
 
       StorageView encoder_output(dtype, device);
-      if (_encoder->_upgraded_model) {
-        encoder_output = maybe_encode(std::move(features));
-      }
-      else {
-        (*_encoder)(features, encoder_output);
-      }
+      (*_encoder)(features, encoder_output);
 
       if (to_cpu) {
         if (device != Device::CPU)
@@ -97,20 +92,6 @@ namespace ctranslate2 {
       return encoder_output;
     }
 
-    StorageView WavLMReplica::maybe_encode(StorageView features) {
-      const Device device = _model->device();
-      const DataType dtype = _encoder->output_type();
-
-      features.move_to(device, dtype);
-
-      if (_encoder->is_encoded(features))
-        return features;
-
-      StorageView encoder_output(dtype, device);
-      (*_encoder)(features, encoder_output);
-      return encoder_output;
-    }
-
     std::future<StorageView> WavLM::encode(const StorageView& features, const bool to_cpu) {
       return post<StorageView>(
           [features = features.sync_copy(), to_cpu](WavLMReplica& replica) mutable {
             return replica.encode(std::move(features), to_cpu);
           });
     }
 
   }
 }
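
--
For reference, a minimal sketch of how the Python API added by this series can be driven end to end. The model path and the random input are illustrative (the audio shape follows the `encode` docstring above, and the model is assumed to have been converted beforehand with the TransformersConverter):

    import numpy as np
    import ctranslate2

    # Load a previously converted WavLM model ("wavlm_ct2" is an illustrative path).
    model = ctranslate2.models.WavLM("wavlm_ct2", device="cpu")

    # Raw mono 16 kHz audio, shaped [batch_size, 1, num_samples].
    audio = np.random.rand(1, 1, 131200).astype(np.float32)
    features = ctranslate2.StorageView.from_array(audio)

    # Returns the final encoder hidden states as a StorageView.
    output = model.encode(features, to_cpu=True)
    print(np.array(output).shape)  # e.g. (1, num_frames, 1024)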