diff --git a/src/cpp/src/whisper_pipeline_static.cpp b/src/cpp/src/whisper_pipeline_static.cpp
index 96622a7fc1..fab5acac30 100644
--- a/src/cpp/src/whisper_pipeline_static.cpp
+++ b/src/cpp/src/whisper_pipeline_static.cpp
@@ -125,68 +125,57 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
 }
 
 void set_decoder_input_ids(ov::InferRequest& decoder,
-                           const std::vector<int32_t>& init_ids) {
+                           const std::vector<int64_t>& init_ids) {
     auto input_ids_tensor = decoder.get_tensor("input_ids");
     const size_t seq_length = input_ids_tensor.get_shape()[1];
 
     OPENVINO_ASSERT(seq_length >= init_ids.size());
 
-    auto input_ids_data = input_ids_tensor.data<int32_t>();
+    auto input_ids_data = input_ids_tensor.data<int64_t>();
     std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
 }
 
-int64_t decode(ov::Tensor& encoder_hidden_state,
-               ov::InferRequest& decoder,
-               const std::vector<int32_t>& init_ids,
-               const ov::genai::WhisperGenerationConfig& config,
-               ov::genai::RawPerfMetrics& raw_metrics,
-               const bool apply_logit_processors = true,
-               const bool return_timestamps = false) {
-    // NB: Fill decoder inputs
-    encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
-    set_decoder_input_ids(decoder, init_ids);
-
-    ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
+void process_whisper_logits(ov::Tensor logits,
+                            const ov::genai::WhisperGenerationConfig& config,
+                            const bool return_timestamps,
+                            const std::vector<int64_t>& generated_tokens) {
+    const bool initial_step = generated_tokens.empty();
 
-    auto output_tensor = decoder.get_tensor("logits");
+    if (initial_step) {
+        ov::genai::do_suppress_tokens(logits, 0, config.begin_suppress_tokens);
+    }
 
-    if (apply_logit_processors) {
-        ov::genai::do_suppress_tokens(output_tensor, 0, config.begin_suppress_tokens);
-        ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);
+    ov::genai::do_suppress_tokens(logits, 0, config.suppress_tokens);
 
-        if (return_timestamps) {
-            ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, {}, true);
-        }
+    if (return_timestamps) {
+        ov::genai::process_whisper_timestamp_logits(logits, 0, config, generated_tokens, initial_step);
     }
-
-    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
-    return output_token;
 }
 
-int64_t decode_with_past(ov::InferRequest& decoder_with_past,
-                         const int64_t input_id,
-                         const int64_t position_id,
-                         const ov::genai::WhisperGenerationConfig& config,
-                         ov::genai::RawPerfMetrics& raw_metrics,
-                         const bool return_timestamps,
-                         const std::vector<int64_t>& generated_tokens) {
-    // FIXME: Avoid this cast to i32. Why it's not i64 precision in model?
-    decoder_with_past.get_tensor("input_ids").data<int32_t>()[0] = static_cast<int32_t>(input_id);
+ov::Tensor decode(ov::Tensor& encoder_hidden_state,
+                  ov::InferRequest& decoder,
+                  const std::vector<int64_t>& init_ids,
+                  ov::genai::RawPerfMetrics& raw_metrics) {
+    // NB: Fill decoder inputs
+    encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
+    set_decoder_input_ids(decoder, init_ids);
+
+    ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
+    return decoder.get_tensor("logits");
+}
+
+ov::Tensor decode_with_past(ov::InferRequest& decoder_with_past,
+                            const int64_t input_id,
+                            const int64_t position_id,
+                            ov::genai::RawPerfMetrics& raw_metrics) {
+    decoder_with_past.get_tensor("input_ids").data<int64_t>()[0] = input_id;
     decoder_with_past.get_tensor("cache_position").data<int64_t>()[0] = position_id;
     // FIXME: Is "attention_mask" supposed to be f16?
     decoder_with_past.get_tensor("attention_mask").data<ov::float16>()[position_id - 1] = 0u;
 
     ov::genai::utils::infer_with_perf_metrics(decoder_with_past, raw_metrics);
-
-    auto output_tensor = decoder_with_past.get_tensor("logits");
-    ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);
-
-    if (return_timestamps) {
-        ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, generated_tokens);
-    }
-
-    int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
-    return output_token;
+    return decoder_with_past.get_tensor("logits");
 }
 
 void zero_past_key_values(ov::InferRequest& request) {
@@ -224,7 +213,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
 
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});
 
-    std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
+    std::vector<int64_t> init_ids{config.decoder_start_token_id};
     set_decoder_input_ids(decoder, init_ids);
 
     const auto infer_start = std::chrono::steady_clock::now();
@@ -250,90 +239,121 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
     return output_token;
 }
 
-std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
+std::vector<int64_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
                                       ov::genai::DecoderCache& decoder_cache,
                                       const ov::genai::WhisperGenerationConfig& config,
                                       const bool return_timestamps,
                                       ov::genai::RawPerfMetrics& raw_metrics) {
     if (!config.is_multilingual) {
         if (return_timestamps) {
-            return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id)};
+            return std::vector<int64_t>{config.decoder_start_token_id};
         } else {
-            return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
-                                        static_cast<int32_t>(config.no_timestamps_token_id)};
+            return std::vector<int64_t>{config.decoder_start_token_id,
+                                        config.no_timestamps_token_id};
         }
     }
 
-    int32_t language_token_id;
+    int64_t language_token_id;
     if (config.language.has_value()) {
         std::string language = *config.language;
         if (config.lang_to_id.count(language)) {
-            language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
+            language_token_id = config.lang_to_id.at(language);
        }
    } else {
        language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
    }
 
-    int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
+    int64_t task_token_id = config.transcribe_token_id;
     if (config.task.has_value() && *config.task == "translate") {
-        task_token_id = static_cast<int32_t>(config.translate_token_id);
+        task_token_id = config.translate_token_id;
     }
 
     if (return_timestamps) {
-        return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
+        return std::vector<int64_t>{config.decoder_start_token_id,
                                     language_token_id,
                                     task_token_id};
     }
 
-    return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
+    return std::vector<int64_t>{config.decoder_start_token_id,
                                 language_token_id,
                                 task_token_id,
-                                static_cast<int32_t>(config.no_timestamps_token_id)};
+                                config.no_timestamps_token_id};
 }
 
-std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_state,
-                                                  const ov::genai::WhisperGenerationConfig& config,
-                                                  ov::genai::WhisperInitializedModels& models,
-                                                  std::vector<int32_t> init_ids,
-                                                  const size_t max_new_tokens,
-                                                  const bool return_timestamps,
-                                                  ov::genai::RawPerfMetrics& raw_metrics,
-                                                  const std::shared_ptr<ov::genai::StreamerBase> streamer) {
-    int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, raw_metrics, true, return_timestamps);
-    std::vector<int64_t> output_tokens{output_token};
-
-    if (!return_timestamps && streamer && streamer->write(output_token) != ov::genai::StreamingStatus::RUNNING) {
-        return {true, output_tokens};
+void stream_generated_tokens(const std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
+                             ov::genai::GenerationHandle& handle,
+                             const bool return_timestamps) {
+    if (return_timestamps || !streamer_ptr || !handle->can_read()) {
+        return;
     }
 
-    if (max_new_tokens == 1) {
-        return {false, output_tokens};
+    std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->read();
+
+    auto streaming_status = streamer_ptr->write(token.begin()->second.generated_ids);
+    if (streaming_status != ov::genai::StreamingStatus::RUNNING) {
+        streaming_status == ov::genai::StreamingStatus::CANCEL ? handle->cancel() : handle->stop();
     }
+}
+
+std::pair<ov::genai::EncodedResults, bool> full_decode(ov::Tensor& encoder_hidden_state,
+                                                       const ov::genai::WhisperGenerationConfig& config,
+                                                       ov::genai::WhisperInitializedModels& models,
+                                                       std::vector<int64_t> init_ids,
+                                                       const bool return_timestamps,
+                                                       ov::genai::RawPerfMetrics& raw_metrics,
+                                                       const std::shared_ptr<ov::genai::StreamerBase> streamer,
+                                                       ov::genai::Sampler& sampler,
+                                                       ov::genai::SequenceGroup::Ptr sequence_group) {
+    auto handle = std::make_shared<ov::genai::GenerationHandleImpl>(sequence_group->get_generation_stream(),
+                                                                    sequence_group->get_sampling_parameters());
+
+    auto logits = decode(encoder_hidden_state, models.decoder, init_ids, raw_metrics);
+    process_whisper_logits(logits, config, return_timestamps, {});
+
+    // sample last token only
+    int64_t output_sequence_len = logits.get_shape().at(1);
+    sequence_group->schedule_tokens(sequence_group->get_prompt_len());
+    sequence_group->set_output_seq_len(output_sequence_len);
+
+    sampler.sample({sequence_group}, logits);
+    stream_generated_tokens(streamer, handle, return_timestamps);
 
     prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());
 
-    for (size_t i = 0; i < max_new_tokens - 1; i++) {
-        auto output_token = decode_with_past(models.decoder_with_past,
-                                             output_tokens.back(),
-                                             i + init_ids.size(),
-                                             config,
-                                             raw_metrics,
-                                             return_timestamps,
-                                             output_tokens);
-        update_past_key_value(models.decoder_with_past, models.decoder_with_past, i + init_ids.size());
-
-        if (output_token == config.eos_token_id) {
-            break;
-        }
+    while (!sequence_group->has_finished() && !sequence_group->handle_stopped() && !sequence_group->handle_cancelled()) {
+        sequence_group->schedule_tokens(1);
+        const auto running_sequences = sequence_group->get_running_sequences();
+        OPENVINO_ASSERT(running_sequences.size() == 1u);
+        auto last_token = running_sequences.front()->get_generated_ids().back();
+        auto last_idx = running_sequences.front()->get_generated_ids().size() - 1;
+
+        auto logits = decode_with_past(models.decoder_with_past,
+                                       last_token,
+                                       last_idx + init_ids.size(),
+                                       raw_metrics);
+        process_whisper_logits(logits, config, return_timestamps, running_sequences.front()->get_generated_ids());
+        update_past_key_value(models.decoder_with_past, models.decoder_with_past, last_idx + init_ids.size());
+
+        sampler.sample({sequence_group}, logits);
+        stream_generated_tokens(streamer, handle, return_timestamps);
+    }
 
-        output_tokens.push_back(output_token);
+    ov::genai::EncodedResults results;
+    // NB: Only batch=1 is supported now
+    results.scores.resize(1u);
+    results.scores[0] = 0u;
+    results.tokens.resize(1u);
 
-        if (!return_timestamps && streamer && streamer->write(output_token) != ov::genai::StreamingStatus::RUNNING) {
-            return {true, output_tokens};
-        }
-    }
+    OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
+    auto sequence = sequence_group->get_finished_sequences().front();
+    results.tokens[0] = sequence->get_generated_ids();
+    results.scores[0] = sequence->get_cumulative_log_prob();
+
+    sampler.clear_request_info(sequence_group->get_request_id());
 
-    return {false, output_tokens};
+    results.perf_metrics.raw_metrics = raw_metrics;
+
+    return {results, (sequence_group->handle_stopped() || sequence_group->handle_cancelled())};
 }
 
 bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder) {
@@ -447,10 +467,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
     ov::preprocess::PrePostProcessor preprocessor(model);
 
     for (auto tensor : model->inputs()) {
-        if (tensor.get_any_name().find("input_ids") != std::string::npos) {
-            preprocessor.input("input_ids").tensor().set_element_type(ov::element::Type_t::i32);
-            preprocessor.input("input_ids").preprocess().convert_element_type(ov::element::Type_t::i32);
-        } else if (tensor.get_any_name().find("attention_mask") != std::string::npos) {
+        if (tensor.get_any_name().find("attention_mask") != std::string::npos) {
             preprocessor.input("attention_mask").tensor().set_element_type(ov::element::Type_t::f16);
             preprocessor.input("attention_mask").preprocess().convert_element_type();
         } else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
@@ -822,7 +839,8 @@ ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {
 
 WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
                                                               const ov::AnyMap& properties)
-    : WhisperPipelineImplBase{models_path} {
+    : WhisperPipelineImplBase{models_path}
+    , m_sampler(m_tokenizer) {
     ov::Core core = utils::singleton_core();
 
     auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
@@ -877,6 +895,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
     if (m_generation_config.eos_token_id == -1) {
         m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
     }
+
+    m_sampler.set_seed(m_generation_config.rng_seed);
 }
 
 WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
@@ -916,7 +936,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
     // long-form audio processing requires timestamps to be enabled
     const bool return_timestamps = config.return_timestamps || !is_shortform;
 
-    std::vector<int32_t> init_ids;
+    std::vector<int64_t> init_ids;
     std::vector<int64_t> output_tokens;
     std::vector<Segment> segments;
 
@@ -947,14 +967,19 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
             m_models.decoder = m_decoder_cache.get_model(init_ids.size());
         }
 
-        auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
-                                                            config,
-                                                            m_models,
-                                                            init_ids,
-                                                            max_new_tokens - output_tokens.size(),
-                                                            return_timestamps,
-                                                            raw_metrics,
-                                                            streamer_ptr);
+        SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(0, init_ids, config, 1);
+
+        auto [results, cancelled] = full_decode(hidden_state_tensor,
+                                                config,
+                                                m_models,
+                                                init_ids,
+                                                return_timestamps,
+                                                raw_metrics,
+                                                streamer_ptr,
+                                                m_sampler,
+                                                sequence_group);
+
+        std::vector<int64_t> chunk_output_tokens = results.tokens[0];
 
         if (return_timestamps) {
            auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens,
diff --git a/src/cpp/src/whisper_pipeline_static.hpp b/src/cpp/src/whisper_pipeline_static.hpp
index 98c2506aaf..ca5c6a0527 100644
--- a/src/cpp/src/whisper_pipeline_static.hpp
+++ b/src/cpp/src/whisper_pipeline_static.hpp
@@ -11,6 +11,7 @@
 #include "openvino/genai/whisper_pipeline.hpp"
 #include "whisper/whisper_models.hpp"
 #include "whisper_pipeline_base.hpp"
+#include "sampler.hpp"
 
 namespace ov {
 namespace genai {
@@ -40,6 +41,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
 private:
     WhisperInitializedModels m_models;
     DecoderCache m_decoder_cache;
+    Sampler m_sampler;
 };
 
 }  // namespace genai