229 changes: 127 additions & 102 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -125,68 +125,57 @@ void update_past_key_value(ov::InferRequest& source, ov::InferRequest& dest, con
}

void set_decoder_input_ids(ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids) {
const std::vector<int64_t>& init_ids) {
auto input_ids_tensor = decoder.get_tensor("input_ids");
const size_t seq_length = input_ids_tensor.get_shape()[1];

OPENVINO_ASSERT(seq_length >= init_ids.size());

auto input_ids_data = input_ids_tensor.data<int32_t>();
auto input_ids_data = input_ids_tensor.data<int64_t>();
std::copy(init_ids.begin(), init_ids.end(), input_ids_data);
}

int64_t decode(ov::Tensor& encoder_hidden_state,
ov::InferRequest& decoder,
const std::vector<int32_t>& init_ids,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics,
const bool apply_logit_processors = true,
const bool return_timestamps = false) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
void process_whisper_logits(ov::Tensor logits,
const ov::genai::WhisperGenerationConfig& config,
const bool return_timestamps,
const std::vector<int64_t>& generated_tokens) {
const bool initial_step = generated_tokens.empty();

auto output_tensor = decoder.get_tensor("logits");
if (initial_step) {
ov::genai::do_suppress_tokens(logits, 0, config.begin_suppress_tokens);
}

if (apply_logit_processors) {
ov::genai::do_suppress_tokens(output_tensor, 0, config.begin_suppress_tokens);
ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);
ov::genai::do_suppress_tokens(logits, 0, config.suppress_tokens);

if (return_timestamps) {
ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, {}, true);
}
if (return_timestamps) {
ov::genai::process_whisper_timestamp_logits(logits, 0, config, generated_tokens, initial_step);
}

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
return output_token;
}

int64_t decode_with_past(ov::InferRequest& decoder_with_past,
const int64_t input_id,
const int64_t position_id,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::RawPerfMetrics& raw_metrics,
const bool return_timestamps,
const std::vector<int64_t>& generated_tokens) {
// FIXME: Avoid this cast to i32. Why it's not i64 precision in model?
decoder_with_past.get_tensor("input_ids").data<int32_t>()[0] = static_cast<int32_t>(input_id);
ov::Tensor decode(ov::Tensor& encoder_hidden_state,
ov::InferRequest& decoder,
const std::vector<int64_t>& init_ids,
ov::genai::RawPerfMetrics& raw_metrics) {
// NB: Fill decoder inputs
encoder_hidden_state.copy_to(decoder.get_tensor("encoder_hidden_states"));
set_decoder_input_ids(decoder, init_ids);

ov::genai::utils::infer_with_perf_metrics(decoder, raw_metrics);
return decoder.get_tensor("logits");
}

ov::Tensor decode_with_past(ov::InferRequest& decoder_with_past,
const int64_t input_id,
const int64_t position_id,
ov::genai::RawPerfMetrics& raw_metrics) {
decoder_with_past.get_tensor("input_ids").data<int64_t>()[0] = input_id;
decoder_with_past.get_tensor("cache_position").data<int64_t>()[0] = position_id;
// FIXME: Is "attention_mask" supposed to be f16?
decoder_with_past.get_tensor("attention_mask").data<ov::float16>()[position_id - 1] = 0u;

ov::genai::utils::infer_with_perf_metrics(decoder_with_past, raw_metrics);

auto output_tensor = decoder_with_past.get_tensor("logits");
ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);

if (return_timestamps) {
ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, generated_tokens);
}

int64_t output_token = ov::genai::utils::argmax(output_tensor, 0);
return output_token;
return decoder_with_past.get_tensor("logits");
}

void zero_past_key_values(ov::InferRequest& request) {
@@ -224,7 +213,7 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,

decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

std::vector<int32_t> init_ids{static_cast<int32_t>(config.decoder_start_token_id)};
std::vector<int64_t> init_ids{config.decoder_start_token_id};
set_decoder_input_ids(decoder, init_ids);

const auto infer_start = std::chrono::steady_clock::now();
@@ -250,90 +239,121 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
return output_token;
}

std::vector<int32_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
std::vector<int64_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
ov::genai::DecoderCache& decoder_cache,
const ov::genai::WhisperGenerationConfig& config,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics) {
if (!config.is_multilingual) {
if (return_timestamps) {
return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id)};
return std::vector<int64_t>{config.decoder_start_token_id};
} else {
return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
static_cast<int32_t>(config.no_timestamps_token_id)};
return std::vector<int64_t>{config.decoder_start_token_id,
config.no_timestamps_token_id};
}
}

int32_t language_token_id;
int64_t language_token_id;
if (config.language.has_value()) {
std::string language = *config.language;
if (config.lang_to_id.count(language)) {
language_token_id = static_cast<int32_t>(config.lang_to_id.at(language));
language_token_id = config.lang_to_id.at(language);
}
} else {
language_token_id = detect_language(encoder_hidden_state, decoder_cache, config, raw_metrics);
}

int32_t task_token_id = static_cast<int32_t>(config.transcribe_token_id);
int64_t task_token_id = config.transcribe_token_id;
if (config.task.has_value() && *config.task == "translate") {
task_token_id = static_cast<int32_t>(config.translate_token_id);
task_token_id = config.translate_token_id;
}

if (return_timestamps) {
return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
return std::vector<int64_t>{config.decoder_start_token_id,
language_token_id,
task_token_id};
}

return std::vector<int32_t>{static_cast<int32_t>(config.decoder_start_token_id),
return std::vector<int64_t>{config.decoder_start_token_id,
language_token_id,
task_token_id,
static_cast<int32_t>(config.no_timestamps_token_id)};
config.no_timestamps_token_id};
}

std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_state,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::WhisperInitializedModels& models,
std::vector<int32_t> init_ids,
const size_t max_new_tokens,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics,
const std::shared_ptr<ov::genai::StreamerBase> streamer) {
int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, raw_metrics, true, return_timestamps);
std::vector<int64_t> output_tokens{output_token};

if (!return_timestamps && streamer && streamer->write(output_token) != ov::genai::StreamingStatus::RUNNING) {
return {true, output_tokens};
void stream_generated_tokens(const std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
ov::genai::GenerationHandle& handle,
const bool return_timestamps) {
if (return_timestamps || !streamer_ptr || !handle->can_read()) {
return;
}

if (max_new_tokens == 1) {
return {false, output_tokens};
std::unordered_map<uint64_t, ov::genai::GenerationOutput> token = handle->read();

auto streaming_status = streamer_ptr->write(token.begin()->second.generated_ids);
if (streaming_status != ov::genai::StreamingStatus::RUNNING) {
streaming_status == ov::genai::StreamingStatus::CANCEL ? handle->cancel() : handle->stop();
}
}

std::pair<ov::genai::EncodedResults, bool> full_decode(ov::Tensor& encoder_hidden_state,
const ov::genai::WhisperGenerationConfig& config,
ov::genai::WhisperInitializedModels& models,
std::vector<int64_t> init_ids,
const bool return_timestamps,
ov::genai::RawPerfMetrics& raw_metrics,
const std::shared_ptr<ov::genai::StreamerBase> streamer,
ov::genai::Sampler& sampler,
ov::genai::SequenceGroup::Ptr sequence_group) {
auto handle = std::make_shared<ov::genai::GenerationHandleImpl>(sequence_group->get_generation_stream(),
sequence_group->get_sampling_parameters());

auto logits = decode(encoder_hidden_state, models.decoder, init_ids, raw_metrics);
process_whisper_logits(logits, config, return_timestamps, {});

// sample last token only
int64_t output_sequence_len = logits.get_shape().at(1);
sequence_group->schedule_tokens(sequence_group->get_prompt_len());
sequence_group->set_output_seq_len(output_sequence_len);

sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer, handle, return_timestamps);

prepare_decoder_with_past(models.decoder_with_past, models.decoder, init_ids.size());

for (size_t i = 0; i < max_new_tokens - 1; i++) {
auto output_token = decode_with_past(models.decoder_with_past,
output_tokens.back(),
i + init_ids.size(),
config,
raw_metrics,
return_timestamps,
output_tokens);
update_past_key_value(models.decoder_with_past, models.decoder_with_past, i + init_ids.size());

if (output_token == config.eos_token_id) {
break;
}
while (!sequence_group->has_finished() && !sequence_group->handle_stopped() && !sequence_group->handle_cancelled()) {
sequence_group->schedule_tokens(1);
const auto running_sequences = sequence_group->get_running_sequences();
OPENVINO_ASSERT(running_sequences.size() == 1u);
auto last_token = running_sequences.front()->get_generated_ids().back();
auto last_idx = running_sequences.front()->get_generated_ids().size() - 1;

auto logits = decode_with_past(models.decoder_with_past,
last_token,
last_idx + init_ids.size(),
raw_metrics);
process_whisper_logits(logits, config, return_timestamps, running_sequences.front()->get_generated_ids());
update_past_key_value(models.decoder_with_past, models.decoder_with_past, last_idx + init_ids.size());

sampler.sample({sequence_group}, logits);
stream_generated_tokens(streamer, handle, return_timestamps);
}

output_tokens.push_back(output_token);
ov::genai::EncodedResults results;
// NB: Only batch=1 is supported now
results.scores.resize(1u);
results.scores[0] = 0u;
results.tokens.resize(1u);

if (!return_timestamps && streamer && streamer->write(output_token) != ov::genai::StreamingStatus::RUNNING) {
return {true, output_tokens};
}
}
OPENVINO_ASSERT(sequence_group->get_finished_sequences().size() == 1u);
auto sequence = sequence_group->get_finished_sequences().front();
results.tokens[0] = sequence->get_generated_ids();
results.scores[0] = sequence->get_cumulative_log_prob();

sampler.clear_request_info(sequence_group->get_request_id());

return {false, output_tokens};
results.perf_metrics.raw_metrics = raw_metrics;

return {results, (sequence_group->handle_stopped() || sequence_group->handle_cancelled())};
}

bool check_decoder_model_compatibility(const std::shared_ptr<ov::Model>& decoder) {
@@ -447,10 +467,7 @@ void preprocess_decoder(std::shared_ptr<ov::Model> model) {
ov::preprocess::PrePostProcessor preprocessor(model);

for (auto tensor : model->inputs()) {
if (tensor.get_any_name().find("input_ids") != std::string::npos) {
preprocessor.input("input_ids").tensor().set_element_type(ov::element::Type_t::i32);
preprocessor.input("input_ids").preprocess().convert_element_type(ov::element::Type_t::i32);
} else if (tensor.get_any_name().find("attention_mask") != std::string::npos) {
if (tensor.get_any_name().find("attention_mask") != std::string::npos) {
preprocessor.input("attention_mask").tensor().set_element_type(ov::element::Type_t::f16);
preprocessor.input("attention_mask").preprocess().convert_element_type();
} else if (tensor.get_any_name().find("encoder_hidden_states") != std::string::npos) {
@@ -822,7 +839,8 @@ ov::InferRequest DecoderCache::get_model(uint8_t input_ids_size) {

WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesystem::path& models_path,
const ov::AnyMap& properties)
: WhisperPipelineImplBase{models_path} {
: WhisperPipelineImplBase{models_path}
, m_sampler(m_tokenizer) {
ov::Core core = utils::singleton_core();

auto encoder_model = core.read_model(models_path / "openvino_encoder_model.xml", {}, properties);
@@ -877,6 +895,8 @@ WhisperPipeline::StaticWhisperPipeline::StaticWhisperPipeline(const std::filesys
if (m_generation_config.eos_token_id == -1) {
m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
}

m_sampler.set_seed(m_generation_config.rng_seed);
}

WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
@@ -916,7 +936,7 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
// long-form audio processing requires timestamps to be enabled
const bool return_timestamps = config.return_timestamps || !is_shortform;

std::vector<int32_t> init_ids;
std::vector<int64_t> init_ids;
std::vector<int64_t> output_tokens;
std::vector<Segment> segments;

@@ -947,14 +967,19 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
m_models.decoder = m_decoder_cache.get_model(init_ids.size());
}

auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
config,
m_models,
init_ids,
max_new_tokens - output_tokens.size(),
return_timestamps,
raw_metrics,
streamer_ptr);
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(0, init_ids, config, 1);

auto [results, cancelled] = full_decode(hidden_state_tensor,
config,
m_models,
init_ids,
return_timestamps,
raw_metrics,
streamer_ptr,
m_sampler,
sequence_group);

std::vector<int64_t> chunk_output_tokens = results.tokens[0];

if (return_timestamps) {
auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens,
2 changes: 2 additions & 0 deletions src/cpp/src/whisper_pipeline_static.hpp
@@ -11,6 +11,7 @@
#include "openvino/genai/whisper_pipeline.hpp"
#include "whisper/whisper_models.hpp"
#include "whisper_pipeline_base.hpp"
#include "sampler.hpp"

namespace ov {
namespace genai {
@@ -40,6 +41,7 @@ class WhisperPipeline::StaticWhisperPipeline : public WhisperPipeline::WhisperPi
private:
WhisperInitializedModels m_models;
DecoderCache m_decoder_cache;
Sampler m_sampler;
};

} // namespace genai
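For context, the refactored static pipeline above is still driven through the public ov::genai::WhisperPipeline API; only the internals move to Sampler/SequenceGroup-based decoding and i64 input_ids. Below is a minimal usage sketch, not part of this PR: the "whisper-base" model directory, the "NPU" device string, and the silent audio buffer are illustrative assumptions.

#include "openvino/genai/whisper_pipeline.hpp"

#include <iostream>

int main() {
    // Assumed for illustration: a directory with an exported Whisper model and the NPU device.
    ov::genai::WhisperPipeline pipeline("whisper-base", "NPU");

    ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config();
    config.return_timestamps = true;  // long-form audio processing requires timestamps (see generate() above)

    // RawSpeechInput is a vector of 16 kHz float PCM samples; one second of silence as a stand-in.
    ov::genai::RawSpeechInput raw_speech(16000, 0.0f);

    ov::genai::WhisperDecodedResults result = pipeline.generate(raw_speech, config);
    std::cout << result.texts[0] << std::endl;
    return 0;
}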