Skip to content

Commit f046f0e

Browse files
committed
Remove unnecessary check from WavLM (refs #1977)
1 parent f511425 commit f046f0e

File tree

3 files changed

+1
-33
lines changed

3 files changed

+1
-33
lines changed

include/ctranslate2/layers/wavlm.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,6 @@ namespace ctranslate2 {
8080
return 1024;
8181
}
8282

83-
bool is_encoded(const StorageView& features) const {
84-
// Input features shape: [batch_size, input_size, input_time]
85-
// Encoder output shape: [batch_size, input_time // 2, output_size]
86-
//
87-
// input_time is variable so we check that dimension 1 is different than its original value.
88-
89-
return (features.rank() == 3
90-
&& features.dim(2) == output_size()
91-
&& features.dim(1) != input_size());
92-
}
93-
9483
const StorageView* _upgraded_model;
9584

9685
private:

include/ctranslate2/models/wavlm.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ namespace ctranslate2 {
5454
private:
5555
const std::shared_ptr<const WavLMModel> _model;
5656
const std::unique_ptr<layers::WavLMEncoder> _encoder;
57-
58-
StorageView maybe_encode(StorageView features);
5957
};
6058

6159
class WavLM : public ReplicaPool<WavLMReplica> {

src/models/wavlm.cc

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,7 @@ namespace ctranslate2 {
7777
features.move_to(device, dtype);
7878

7979
StorageView encoder_output(dtype, device);
80-
if (_encoder->_upgraded_model) {
81-
encoder_output = maybe_encode(std::move(features));
82-
}
83-
else {
84-
(*_encoder)(features, encoder_output);
85-
}
80+
(*_encoder)(features, encoder_output);
8681

8782
if (to_cpu) {
8883
if (device != Device::CPU)
@@ -97,20 +92,6 @@ namespace ctranslate2 {
9792
return encoder_output;
9893
}
9994

100-
StorageView WavLMReplica::maybe_encode(StorageView features) {
101-
const Device device = _model->device();
102-
const DataType dtype = _encoder->output_type();
103-
104-
features.move_to(device, dtype);
105-
106-
if (_encoder->is_encoded(features))
107-
return features;
108-
109-
StorageView encoder_output(dtype, device);
110-
(*_encoder)(features, encoder_output);
111-
return encoder_output;
112-
}
113-
11495
std::future<StorageView> WavLM::encode(const StorageView& features, const bool to_cpu) {
11596
return post<StorageView>(
11697
[features = features.sync_copy(), to_cpu](WavLMReplica& replica) mutable {

0 commit comments

Comments
 (0)