File tree Expand file tree Collapse file tree 3 files changed +1
-33
lines changed
Expand file tree Collapse file tree 3 files changed +1
-33
lines changed Original file line number Diff line number Diff line change @@ -80,17 +80,6 @@ namespace ctranslate2 {
8080 return 1024 ;
8181 }
8282
83- bool is_encoded (const StorageView& features) const {
84- // Input features shape: [batch_size, input_size, input_time]
85- // Encoder output shape: [batch_size, input_time // 2, output_size]
86- //
87- // input_time is variable so we check that dimension 1 is different than its original value.
88-
89- return (features.rank () == 3
90- && features.dim (2 ) == output_size ()
91- && features.dim (1 ) != input_size ());
92- }
93-
9483 const StorageView* _upgraded_model;
9584
9685 private:
Original file line number Diff line number Diff line change @@ -54,8 +54,6 @@ namespace ctranslate2 {
5454 private:
5555 const std::shared_ptr<const WavLMModel> _model;
5656 const std::unique_ptr<layers::WavLMEncoder> _encoder;
57-
58- StorageView maybe_encode (StorageView features);
5957 };
6058
6159 class WavLM : public ReplicaPool <WavLMReplica> {
Original file line number Diff line number Diff line change @@ -77,12 +77,7 @@ namespace ctranslate2 {
7777 features.move_to (device, dtype);
7878
7979 StorageView encoder_output (dtype, device);
80- if (_encoder->_upgraded_model ) {
81- encoder_output = maybe_encode (std::move (features));
82- }
83- else {
84- (*_encoder)(features, encoder_output);
85- }
80+ (*_encoder)(features, encoder_output);
8681
8782 if (to_cpu) {
8883 if (device != Device::CPU)
@@ -97,20 +92,6 @@ namespace ctranslate2 {
9792 return encoder_output;
9893 }
9994
100- StorageView WavLMReplica::maybe_encode (StorageView features) {
101- const Device device = _model->device ();
102- const DataType dtype = _encoder->output_type ();
103-
104- features.move_to (device, dtype);
105-
106- if (_encoder->is_encoded (features))
107- return features;
108-
109- StorageView encoder_output (dtype, device);
110- (*_encoder)(features, encoder_output);
111- return encoder_output;
112- }
113-
11495 std::future<StorageView> WavLM::encode (const StorageView& features, const bool to_cpu) {
11596 return post <StorageView>(
11697 [features = features.sync_copy (), to_cpu](WavLMReplica& replica) mutable {
You can’t perform that action at this time.
0 commit comments