|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +// Given a image tensor, prefill the KV cache of LLaVA. |
| 10 | + |
| 11 | +#include <executorch/extension/llm/runner/constants.h> |
| 12 | +#include <executorch/extension/llm/runner/image_prefiller.h> |
| 13 | +#include <executorch/extension/tensor/tensor.h> |
| 14 | + |
| 15 | +namespace executorch::extension::llm { |
| 16 | +/** |
| 17 | + * Prefill an LLM Module with the given image input. |
| 18 | + * @param image The image input to LLaVa. |
| 19 | + * @param start_pos The starting position in KV cache of the input in the LLM |
| 20 | + * @return logits of the image prefill. |
| 21 | + */ |
| 22 | +::executorch::runtime::Result<uint64_t> ImagePrefiller::prefill( |
| 23 | + ::executorch::extension::llm::Image& image, |
| 24 | + int64_t& start_pos) { |
| 25 | + auto image_tensor = executorch::extension::from_blob( |
| 26 | + image.data.data(), |
| 27 | + {3, image.height, image.width}, |
| 28 | + ::executorch::aten::ScalarType::Byte); |
| 29 | + // Run image encoder |
| 30 | + auto image_encoder_outputs = |
| 31 | + ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); |
| 32 | + |
| 33 | + // inputs:[start_pos, embeds] |
| 34 | + auto start_pos_tensor = executorch::extension::from_blob( |
| 35 | + &start_pos, {1}, ::executorch::aten::ScalarType::Long); |
| 36 | + |
| 37 | + // Run text model |
| 38 | + auto outputs_res = ET_UNWRAP(module_->execute( |
| 39 | + kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]})); |
| 40 | + ET_CHECK_MSG( |
| 41 | + outputs_res[0].isTensor(), |
| 42 | + "Non Tensor Output returned from executing image prefill"); |
| 43 | + |
| 44 | + // Update the start_pos, which is only available inside this function. |
| 45 | + // outputs_res can have only one logits. |
| 46 | + start_pos += image_encoder_outputs[0].toTensor().size(1); |
| 47 | + |
| 48 | + return logits_to_token(outputs_res[0].toTensor()); |
| 49 | +} |
| 50 | + |
| 51 | +/** |
| 52 | + * Load the Module for image prefill purpose. |
| 53 | + * @return The error code. |
| 54 | + */ |
| 55 | +::executorch::runtime::Error ImagePrefiller::load() { |
| 56 | + if (is_method_loaded()) { |
| 57 | + return ::executorch::runtime::Error::Ok; |
| 58 | + } |
| 59 | + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); |
| 60 | + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); |
| 61 | + return ::executorch::runtime::Error::Ok; |
| 62 | +} |
| 63 | + |
| 64 | +/** |
| 65 | + * Check if the required methods in the Module is loaded. |
| 66 | + * @return True if the Module is loaded, false otherwise. |
| 67 | + */ |
| 68 | +bool ImagePrefiller::is_method_loaded() { |
| 69 | + ::executorch::runtime::Result<std::unordered_set<std::string>> methods_res = |
| 70 | + module_->method_names(); |
| 71 | + if (methods_res.error() != ::executorch::runtime::Error::Ok) { |
| 72 | + ET_CHECK_MSG(false, "Failed to get method names"); |
| 73 | + } |
| 74 | + std::unordered_set<std::string> methods = methods_res.get(); |
| 75 | + bool methods_exist = methods.find(kImageEncoderMethod) != methods.end() && |
| 76 | + methods.find(kTextModelMethod) != methods.end(); |
| 77 | + if (!methods_exist) { |
| 78 | + for (const auto& method : methods) { |
| 79 | + ET_LOG(Error, "Method: %s", method.c_str()); |
| 80 | + } |
| 81 | + ET_CHECK_MSG( |
| 82 | + methods_exist, |
| 83 | + "Missing required methods (%s, %s) in the model", |
| 84 | + kImageEncoderMethod, |
| 85 | + kTextModelMethod); |
| 86 | + } |
| 87 | + bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) && |
| 88 | + module_->is_method_loaded(kTextModelMethod); |
| 89 | + return methods_loaded; |
| 90 | +} |
| 91 | + |
| 92 | +} // namespace executorch::extension::llm |
0 commit comments