|
| 1 | +// Copyright 2025 Google LLC. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#ifndef THIRD_PARTY_ODML_LITERT_LITERT_CC_INTERNAL_LITERT_COMPILED_MODEL_NEXT_H_ |
| 16 | +#define THIRD_PARTY_ODML_LITERT_LITERT_CC_INTERNAL_LITERT_COMPILED_MODEL_NEXT_H_ |
| 17 | + |
| 18 | +#include <cstddef> |
| 19 | +#include <optional> |
| 20 | +#include <string> |
| 21 | + |
| 22 | +#include "absl/strings/string_view.h" // from @com_google_absl |
| 23 | +#include "litert/c/litert_common.h" |
| 24 | +#include "litert/c/litert_compiled_model.h" |
| 25 | +#include "litert/cc/internal/litert_handle.h" |
| 26 | +#include "litert/cc/litert_common.h" |
| 27 | +#include "litert/cc/litert_compiled_model.h" |
| 28 | +#include "litert/cc/litert_environment.h" |
| 29 | +#include "litert/cc/litert_expected.h" |
| 30 | +#include "litert/cc/litert_macros.h" |
| 31 | +#include "litert/cc/litert_model.h" |
| 32 | + |
| 33 | +namespace litert { |
| 34 | + |
| 35 | +// Advanced CompiledModel with new / experimental features. |
| 36 | +class CompiledModelNext : public CompiledModel { |
| 37 | + public: |
| 38 | + static Expected<CompiledModelNext> Create( |
| 39 | + litert::Environment& env, const litert::Model& model, |
| 40 | + litert::HwAccelerators hardware_accelerators); |
| 41 | + |
| 42 | + // Sets a dispatch annotation on the compiled model. These annotations will be |
| 43 | + // propagated to dispatch graphs when they are created during model execution. |
| 44 | + // The annotations provide runtime hints and metadata that can be used by |
| 45 | + // hardware accelerators for optimization. |
| 46 | + // |
| 47 | + // Parameters: |
| 48 | + // - signature_index: the index of the signature (zero-based). |
| 49 | + // - key: the annotation key. |
| 50 | + // - value: the annotation value. |
| 51 | + // |
| 52 | + // Example annotations: |
| 53 | + // - "priority": "high|medium|low" - execution priority hints |
| 54 | + // - "memory_type": "shared|dedicated" - memory allocation preferences |
| 55 | + // - "accelerator": "npu|gpu|dsp" - preferred hardware accelerator |
| 56 | + // - "precision": "fp32|fp16|int8" - computation precision requirements |
| 57 | + Expected<void> SetDispatchAnnotation(size_t signature_index, |
| 58 | + absl::string_view key, |
| 59 | + absl::string_view value) { |
| 60 | + LITERT_RETURN_IF_ERROR(LiteRtCompiledModelSetDispatchAnnotation( |
| 61 | + Get(), signature_index, key.data(), value.data())); |
| 62 | + return {}; |
| 63 | + } |
| 64 | + |
| 65 | + // Gets a dispatch annotation from the compiled model. |
| 66 | + // |
| 67 | + // Parameters: |
| 68 | + // - signature_index: the index of the signature (zero-based). |
| 69 | + // - key: the annotation key to look up. |
| 70 | + // |
| 71 | + // Returns: |
| 72 | + // - The annotation value if found, or nullopt if the key doesn't exist. |
| 73 | + Expected<std::optional<std::string>> GetDispatchAnnotation( |
| 74 | + size_t signature_index, absl::string_view key) { |
| 75 | + const char* value = nullptr; |
| 76 | + LITERT_RETURN_IF_ERROR(LiteRtCompiledModelGetDispatchAnnotation( |
| 77 | + Get(), signature_index, key.data(), &value)); |
| 78 | + if (value == nullptr) { |
| 79 | + return Expected<std::optional<std::string>>(std::nullopt); |
| 80 | + } |
| 81 | + return Expected<std::optional<std::string>>(std::string(value)); |
| 82 | + } |
| 83 | + |
| 84 | + // Removes a dispatch annotation from the compiled model. |
| 85 | + // |
| 86 | + // Parameters: |
| 87 | + // - signature_index: the index of the signature (zero-based). |
| 88 | + // - key: the annotation key to remove. |
| 89 | + // |
| 90 | + // Note: This function succeeds even if the key doesn't exist. |
| 91 | + Expected<void> RemoveDispatchAnnotation(size_t signature_index, |
| 92 | + absl::string_view key) { |
| 93 | + LITERT_RETURN_IF_ERROR(LiteRtCompiledModelRemoveDispatchAnnotation( |
| 94 | + Get(), signature_index, key.data())); |
| 95 | + return {}; |
| 96 | + } |
| 97 | + |
| 98 | + // Overloaded version for the default signature (index 0). |
| 99 | + Expected<void> SetDispatchAnnotation(absl::string_view key, |
| 100 | + absl::string_view value) { |
| 101 | + return SetDispatchAnnotation(0, key, value); |
| 102 | + } |
| 103 | + |
| 104 | + // Overloaded version for the default signature (index 0). |
| 105 | + Expected<std::optional<std::string>> GetDispatchAnnotation( |
| 106 | + absl::string_view key) { |
| 107 | + return GetDispatchAnnotation(0, key); |
| 108 | + } |
| 109 | + |
| 110 | + // Overloaded version for the default signature (index 0). |
| 111 | + Expected<void> RemoveDispatchAnnotation(absl::string_view key) { |
| 112 | + return RemoveDispatchAnnotation(0, key); |
| 113 | + } |
| 114 | + |
| 115 | + // Overloaded version that takes a signature name instead of index. |
| 116 | + Expected<void> SetDispatchAnnotation(absl::string_view signature_name, |
| 117 | + absl::string_view key, |
| 118 | + absl::string_view value) { |
| 119 | + LITERT_ASSIGN_OR_RETURN(size_t signature_index, |
| 120 | + model_.GetSignatureIndex(signature_name)); |
| 121 | + return SetDispatchAnnotation(signature_index, key, value); |
| 122 | + } |
| 123 | + |
| 124 | + // Overloaded version that takes a signature name instead of index. |
| 125 | + Expected<std::optional<std::string>> GetDispatchAnnotation( |
| 126 | + absl::string_view signature_name, absl::string_view key) { |
| 127 | + LITERT_ASSIGN_OR_RETURN(size_t signature_index, |
| 128 | + model_.GetSignatureIndex(signature_name)); |
| 129 | + return GetDispatchAnnotation(signature_index, key); |
| 130 | + } |
| 131 | + |
| 132 | + // Overloaded version that takes a signature name instead of index. |
| 133 | + Expected<void> RemoveDispatchAnnotation(absl::string_view signature_name, |
| 134 | + absl::string_view key) { |
| 135 | + LITERT_ASSIGN_OR_RETURN(size_t signature_index, |
| 136 | + model_.GetSignatureIndex(signature_name)); |
| 137 | + return RemoveDispatchAnnotation(signature_index, key); |
| 138 | + } |
| 139 | + |
| 140 | + private: |
| 141 | + explicit CompiledModelNext(LiteRtModel litert_model, |
| 142 | + LiteRtCompiledModel compiled_model, |
| 143 | + OwnHandle owned) |
| 144 | + : CompiledModel(litert_model, compiled_model, owned) {} |
| 145 | +}; |
| 146 | + |
| 147 | +} // namespace litert |
| 148 | + |
| 149 | +#endif // THIRD_PARTY_ODML_LITERT_LITERT_CC_INTERNAL_LITERT_COMPILED_MODEL_NEXT_H_ |
0 commit comments