
Commit b91c2e9

hheydary authored and copybara-github committed
Add LoRA support to AI Edge Transformers.
PiperOrigin-RevId: 713037640
1 parent 603e8ea commit b91c2e9

File tree

4 files changed: +227 -9 lines changed

ai_edge_torch/generative/examples/cpp/BUILD
ai_edge_torch/generative/examples/cpp/text_generator_main.cc
ai_edge_torch/generative/examples/cpp/utils.cc
ai_edge_torch/generative/examples/cpp/utils.h

ai_edge_torch/generative/examples/cpp/BUILD

Lines changed: 8 additions & 0 deletions
@@ -24,9 +24,17 @@ package(

 cc_library(
     name = "utils",
+    srcs = ["utils.cc"],
     hdrs = ["utils.h"],
     deps = [
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite:framework",
         "@org_tensorflow//tensorflow/lite:util",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
     ],
 )

ai_edge_torch/generative/examples/cpp/text_generator_main.cc

Lines changed: 29 additions & 9 deletions
@@ -71,10 +71,12 @@ ABSL_FLAG(std::string, stop_token, "",
 ABSL_FLAG(int, num_threads, 4, "Number of threads to use. Defaults to 4.");
 ABSL_FLAG(std::string, weight_cache_path, "",
           "XNNPACK weight caching path, e.g. /tmp/model.xnnpack_cache.");
+ABSL_FLAG(std::string, lora_path, "", "Optional path to LoRA artifact.");

 namespace {

 using ai_edge_torch::examples::AlignedAllocator;
+using ai_edge_torch::examples::LoRA;

 std::unique_ptr<tflite::FlatBufferModel> LoadModel() {
   std::unique_ptr<tflite::FlatBufferModel> model =
@@ -172,23 +174,32 @@ void PrepareRunner(
 tflite::SignatureRunner* GetPrefillRunner(
     tflite::Interpreter* interpreter, std::size_t num_input_tokens,
     std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
-        kv_cache) {
-  // Find the prefill signature that best matches the input token size.
+        kv_cache,
+    const LoRA* lora) {
+  // Find the prefill signature length that best matches the input token size.
   tflite::SignatureRunner* runner = nullptr;
+  int best_seq_size = -1;
   int delta = std::numeric_limits<int>::max();
   for (const std::string* key : interpreter->signature_keys()) {
-    if (!absl::StrContains(*key, "prefill")) {
+    if (!absl::StrContains(*key, "prefill") ||
+        absl::StrContains(*key, "lora")) {
       continue;
     }
     TfLiteTensor* input_pos = interpreter->GetSignatureRunner(key->c_str())
                                   ->input_tensor("input_pos");
     // The expected shape for input position is [Seq].
     int seq_size = input_pos->dims->data[0];
     if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
-      runner = interpreter->GetSignatureRunner(key->c_str());
+      if (lora == nullptr) {
+        runner = interpreter->GetSignatureRunner(key->c_str());
+      }
+      best_seq_size = seq_size;
       delta = seq_size - num_input_tokens;
     }
   }
+  if (lora != nullptr) {
+    runner = lora->GetPrefillRunner(interpreter, best_seq_size);
+  }
   MINIMAL_CHECK(runner != nullptr);
   PrepareRunner(runner, kv_cache);
   return runner;
@@ -197,8 +208,11 @@ tflite::SignatureRunner* GetPrefillRunner(
 tflite::SignatureRunner* GetDecodeRunner(
     tflite::Interpreter* interpreter,
     std::map<std::string, std::vector<float, AlignedAllocator<float>>>&
-        kv_cache) {
-  tflite::SignatureRunner* runner = interpreter->GetSignatureRunner("decode");
+        kv_cache,
+    LoRA* lora) {
+  tflite::SignatureRunner* runner =
+      lora == nullptr ? interpreter->GetSignatureRunner("decode")
+                      : lora->GetDecodeRunner(interpreter);
   MINIMAL_CHECK(runner != nullptr);
   PrepareRunner(runner, kv_cache);
   return runner;
@@ -242,7 +256,13 @@ int main(int argc, char* argv[]) {
       LoadSentencePieceProcessor();
   std::map<std::string, std::vector<float, AlignedAllocator<float>>> kv_cache =
       BuildKVCache(interpreter.get());
-  MINIMAL_CHECK(!kv_cache.empty())
+  MINIMAL_CHECK(!kv_cache.empty());
+
+  std::unique_ptr<LoRA> lora = nullptr;
+  if (!absl::GetFlag(FLAGS_lora_path).empty()) {
+    lora = LoRA::FromFile(absl::GetFlag(FLAGS_lora_path));
+    MINIMAL_CHECK(lora != nullptr);
+  }

   // Tokenize the input prompt.
   std::string prompt = absl::GetFlag(FLAGS_prompt);
@@ -263,10 +283,10 @@ int main(int argc, char* argv[]) {
   // Get prefill and decode signature runners.
   std::size_t effective_prefill_token_size = prompt_tokens.size() - 1;
   tflite::SignatureRunner* prefill_runner = GetPrefillRunner(
-      interpreter.get(), effective_prefill_token_size, kv_cache);
+      interpreter.get(), effective_prefill_token_size, kv_cache, lora.get());
   MINIMAL_CHECK(prefill_runner != nullptr);
   tflite::SignatureRunner* decode_runner =
-      GetDecodeRunner(interpreter.get(), kv_cache);
+      GetDecodeRunner(interpreter.get(), kv_cache, lora.get());
   MINIMAL_CHECK(decode_runner != nullptr);

   // Get Input Tensors for each of the runners.
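The prefill-runner selection above changes in only one place when LoRA is enabled: the loop over the base "prefill" signatures is still used to find the smallest exported sequence length that fits the prompt, but the runner itself is then taken from the LoRA object. A minimal sketch of that selection rule, with an illustrative function name and a plain vector standing in for the interpreter's signature list:

#include <limits>
#include <vector>

// Illustrative only: mirrors the matching logic in GetPrefillRunner. Returns
// the smallest exported prefill length that can hold the prompt, or -1 if
// none is large enough.
int PickPrefillSeqLen(const std::vector<int>& available_lengths,
                      int num_input_tokens) {
  int best_seq_size = -1;
  int delta = std::numeric_limits<int>::max();
  for (int seq_size : available_lengths) {
    if (num_input_tokens <= seq_size && seq_size - num_input_tokens < delta) {
      best_seq_size = seq_size;
      delta = seq_size - num_input_tokens;
    }
  }
  return best_seq_size;
}

With --lora_path set, this best-matching length is what GetPrefillRunner passes to LoRA::GetPrefillRunner instead of using the base signature directly.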
ai_edge_torch/generative/examples/cpp/utils.cc

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+/* Copyright 2025 The AI Edge Torch Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ai_edge_torch/generative/examples/cpp/utils.h"
+
+#include <cstddef>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/model_builder.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/signature_runner.h"
+
+namespace ai_edge_torch::examples {
+
+std::unique_ptr<LoRA> LoRA::FromFile(absl::string_view path) {
+  std::unique_ptr<tflite::FlatBufferModel> model =
+      tflite::FlatBufferModel::VerifyAndBuildFromFile(path.data());
+  if (model == nullptr) {
+    return nullptr;
+  }
+
+  int rank = -1;
+  absl::flat_hash_map<std::string, std::vector<float, AlignedAllocator<float>>>
+      tensors;
+  for (const auto& tensor :
+       *model->GetModel()->subgraphs()->Get(0)->tensors()) {
+    size_t size = 1;
+    for (const int& dim : *tensor->shape()) {
+      size *= dim;
+    }
+    std::vector<float, AlignedAllocator<float>> buffer(size);
+    const auto* data =
+        model->GetModel()->buffers()->Get(tensor->buffer())->data();
+    memcpy(buffer.data(), data->data(), data->size());
+    tensors.emplace(*tensor->name(), std::move(buffer));
+
+    if (tensor->name()->str() == "lora_atten_q_a_prime_weight_0") {
+      rank = tensor->shape()->Get(1);
+    }
+  }
+  if (rank == -1) {
+    return nullptr;
+  }
+
+  return absl::WrapUnique(new LoRA(rank, std::move(tensors)));
+}
+
+tflite::SignatureRunner* LoRA::GetPrefillRunner(
+    tflite::Interpreter* interpreter, int matched_sequence_length) const {
+  std::string signature_name =
+      absl::StrFormat("prefill_%d_lora_r%d", matched_sequence_length, rank_);
+  return GetRunnerHelper(interpreter, signature_name);
+}
+
+tflite::SignatureRunner* LoRA::GetDecodeRunner(
+    tflite::Interpreter* interpreter) const {
+  std::string signature_name = absl::StrFormat("decode_lora_r%d", rank_);
+  return GetRunnerHelper(interpreter, signature_name);
+};
+
+tflite::SignatureRunner* LoRA::GetRunnerHelper(
+    tflite::Interpreter* interpreter, absl::string_view signature_name) const {
+  tflite::SignatureRunner* runner =
+      interpreter->GetSignatureRunner(signature_name.data());
+  if (runner == nullptr) {
+    return nullptr;
+  }
+
+  absl::flat_hash_set<std::string> lora_input_tensors;
+  lora_input_tensors.reserve(runner->input_size());
+  for (const char* input_name : runner->input_names()) {
+    if (absl::StrContains(input_name, "lora")) {
+      lora_input_tensors.insert(input_name);
+    }
+  }
+
+  if (lora_input_tensors.size() < tensors_.size()) {
+    return nullptr;
+  }
+
+  for (const auto& [name, buffer] : tensors_) {
+    TfLiteTensor* tensor = runner->input_tensor(name.c_str());
+    if (tensor == nullptr) {
+      return nullptr;
+    }
+    lora_input_tensors.erase(name);
+    TfLiteCustomAllocation allocation = {
+        .data = static_cast<void*>(const_cast<float*>(buffer.data())),
+        .bytes = buffer.size() * sizeof(float)};
+    if (runner->SetCustomAllocationForInputTensor(name.c_str(), allocation) !=
+        kTfLiteOk) {
+      return nullptr;
+    }
+  };
+  if (runner->AllocateTensors() != kTfLiteOk) {
+    return nullptr;
+  }
+
+  for (const auto& name : lora_input_tensors) {
+    TfLiteTensor* tensor = runner->input_tensor(name.c_str());
+    if (tensor == nullptr) {
+      return nullptr;
+    }
+    memset(tensor->data.data, 0, tensor->bytes);
+  }
+
+  return runner;
+}
+
+} // namespace ai_edge_torch::examples
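GetPrefillRunner and GetDecodeRunner above do little more than format a signature name from the matched sequence length and the adapter rank (inferred in FromFile from the second dimension of lora_atten_q_a_prime_weight_0) and hand it to GetRunnerHelper. A small sketch of that naming convention; the helper names and the example values in the comments are illustrative, only the format strings come from the code above:

#include <string>

#include "absl/strings/str_format.h"

// e.g. seq_len = 128, rank = 16 -> "prefill_128_lora_r16"
std::string LoraPrefillSignature(int seq_len, int rank) {
  return absl::StrFormat("prefill_%d_lora_r%d", seq_len, rank);
}

// e.g. rank = 16 -> "decode_lora_r16"
std::string LoraDecodeSignature(int rank) {
  return absl::StrFormat("decode_lora_r%d", rank);
}

The converted model therefore has to export these LoRA-specific signatures alongside the base prefill/decode ones; otherwise GetSignatureRunner returns nullptr and the helper bails out.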

ai_edge_torch/generative/examples/cpp/utils.h

Lines changed: 57 additions & 0 deletions
@@ -1,8 +1,31 @@
+/* Copyright 2025 The AI Edge Torch Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
 #ifndef THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
 #define THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_

 #include <cstddef>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>

+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/signature_runner.h"
 #include "tensorflow/lite/util.h"

 namespace ai_edge_torch::examples {
@@ -39,6 +62,40 @@ class AlignedAllocator {
   void deallocate(T* ptr, std::size_t n) { free(ptr); }
 };

+// An example implementation of a LoRA adapter manager for the TFLite
+// interpreter. The class loads an adapter from a flatbuffers file and
+// provides helper methods for finding the right signature and setting the
+// appropriate input tensors. Note the use of custom allocations to ensure
+// zero-copy loading and cheap hot-swapping between multiple adapters.
+class LoRA {
+ public:
+  static std::unique_ptr<LoRA> FromFile(absl::string_view path);
+
+  tflite::SignatureRunner* GetPrefillRunner(tflite::Interpreter* interpreter,
+                                            int matched_sequence_length) const;
+  tflite::SignatureRunner* GetDecodeRunner(
+      tflite::Interpreter* interpreter) const;
+
+  int rank() const { return rank_; };
+
+ private:
+  explicit LoRA(int rank,
+                absl::flat_hash_map<std::string,
+                                    std::vector<float, AlignedAllocator<float>>>
+                    tensors)
+      : rank_(rank), tensors_(std::move(tensors)) {}
+
+  tflite::SignatureRunner* GetRunnerHelper(
+      tflite::Interpreter* interpreter, absl::string_view signature_name) const;
+
+  // The rank of the LoRA adapter.
+  const int rank_;
+  // A map of names to LoRA tensors.
+  const absl::flat_hash_map<std::string,
+                            std::vector<float, AlignedAllocator<float>>>
+      tensors_;
+};
+
 } // namespace ai_edge_torch::examples

 #endif // THIRD_PARTY_PY_AI_EDGE_TORCH_GENERATIVE_EXAMPLES_CPP_UTILS_H_
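Putting the pieces together, a minimal usage sketch of the LoRA class declared above; the adapter path and the sequence length are placeholders, and error handling is reduced to early returns:

#include <memory>

#include "ai_edge_torch/generative/examples/cpp/utils.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/signature_runner.h"

using ai_edge_torch::examples::LoRA;

// Binds a LoRA adapter to the interpreter's "*_lora_r<rank>" signatures.
bool ApplyAdapter(tflite::Interpreter* interpreter) {
  // Rank is inferred from the adapter's tensor shapes during loading.
  std::unique_ptr<LoRA> lora = LoRA::FromFile("/tmp/adapter.tflite");
  if (lora == nullptr) return false;

  tflite::SignatureRunner* prefill =
      lora->GetPrefillRunner(interpreter, /*matched_sequence_length=*/128);
  tflite::SignatureRunner* decode = lora->GetDecodeRunner(interpreter);
  return prefill != nullptr && decode != nullptr;
}

Because the LoRA inputs are bound with SetCustomAllocationForInputTensor, the adapter weights stay in the buffers owned by the LoRA object, which is what enables the zero-copy loading and cheap adapter swapping mentioned in the class comment.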
