
Commit 486dd4e

First working llama shared weights flow
1 parent 84c81a3 commit 486dd4e

10 files changed (+347, -120 lines)

backends/mediatek/runtime/NeuronBackend.cpp

Lines changed: 28 additions & 8 deletions
@@ -73,7 +73,26 @@ Result<DelegateHandle*> NeuronBackend::init(
         "NeuronBackend",
         "SharedWeights Enabled for %s",
         shared_weights_key.c_str());
-
+    std::shared_ptr<NeuronSharedWeights> neuron_shared_weights;
+    if (neuron_shared_weights_cache_.find(shared_weights_key) !=
+        neuron_shared_weights_cache_.end()) {
+      neuron_shared_weights =
+          neuron_shared_weights_cache_.at(shared_weights_key).lock();
+      if (neuron_shared_weights) {
+        LogInfo(
+            "NeuronBackend",
+            "Reusing cached shared weights with key %s",
+            shared_weights_key.c_str());
+        delegate->SetSharedWeights(neuron_shared_weights);
+        continue;
+      } else {
+        LogInfo(
+            "NeuronBackend",
+            "Shared weights cache expired: %s",
+            shared_weights_key.c_str());
+        neuron_shared_weights_cache_.erase(shared_weights_key); // Expired
+      }
+    }
     const NamedDataMap* named_data_map = context.get_named_data_map();
     Result<FreeableBuffer> shared_weights =
         named_data_map->get_data(shared_weights_key.c_str());
@@ -84,7 +103,11 @@ Result<DelegateHandle*> NeuronBackend::init(
           "Loaded shared weights from named_data_map. Size: %zu",
           shared_weights.get().size());
       FreeableBuffer& buffer = shared_weights.get();
-      delegate->SetSharedWeights(buffer);
+      neuron_shared_weights =
+          std::make_shared<NeuronSharedWeights>(std::move(buffer));
+      delegate->SetSharedWeights(neuron_shared_weights);
+      neuron_shared_weights_cache_[shared_weights_key] =
+          neuron_shared_weights;
     } else {
       LogError(
           "NeuronBackend",
@@ -148,13 +171,10 @@ Error NeuronExecuTorchDelegate::execute(
   auto allocator = dynamic_cast<torch::executor::neuron::BufferAllocator*>(
       context.get_temp_allocator());

-  bool has_shared_weights_input = neuron_shared_weights_.size() > 0;
-
-  size_t inputCount =
-      has_shared_weights_input ? mInputSizes.size() + 1 : mInputSizes.size();
+  size_t inputCount = mInputSizes.size() + neuron_shared_weights_.size();
   size_t outputCount = mOutputSizes.size();

-  for (int i = 0; i < inputCount; i++) {
+  for (size_t i = 0; i < inputCount; i++) {
     auto data_ptr = mPreparedInputs[i].data_ptr;
     auto data_size = mPreparedInputs[i].size;
     if (IsCached</*isInput=*/true>(i, data_ptr)) {
@@ -171,7 +191,7 @@ Error NeuronExecuTorchDelegate::execute(
     }
   }

-  for (int o = 0; o < outputCount; o++) {
+  for (size_t o = 0; o < outputCount; o++) {
     auto data_ptr = mPreparedOutputs[o].data_ptr;
     auto data_size = mPreparedOutputs[o].size;
     if (IsCached</*isInput=*/false>(o, data_ptr)) {
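Note: the init() change above keys a std::weak_ptr cache by the shared-weights name, so a second delegate that asks for the same key reuses the already-allocated buffer instead of copying it again, while an expired entry is erased and rebuilt. A minimal standalone sketch of that lookup, lock, or rebuild pattern follows; Resource and load_resource are illustrative stand-ins, not backend APIs.

// Sketch of the weak_ptr caching pattern used in init() above.
// "Resource" and "load_resource" are hypothetical stand-ins.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Resource {
  std::string payload;
};

std::unordered_map<std::string, std::weak_ptr<Resource>> cache;

std::shared_ptr<Resource> load_resource(const std::string& key) {
  auto it = cache.find(key);
  if (it != cache.end()) {
    if (auto cached = it->second.lock()) {
      return cached;   // Still alive: reuse without reloading.
    }
    cache.erase(it);   // All owners released it: entry expired.
  }
  auto fresh = std::make_shared<Resource>(Resource{"weights for " + key});
  cache[key] = fresh;  // The cache observes but does not extend the lifetime.
  return fresh;
}

int main() {
  auto a = load_resource("llama_chunk_0");
  auto b = load_resource("llama_chunk_0");  // Reuses the same allocation.
  std::cout << (a == b) << "\n";            // Prints 1.
}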

backends/mediatek/runtime/include/NeuronBackend.h

Lines changed: 53 additions & 9 deletions
@@ -32,6 +32,45 @@ using executorch::runtime::EValue;
 using executorch::runtime::FreeableBuffer;
 using executorch::runtime::Result;

+class NeuronSharedWeights {
+ public:
+  explicit NeuronSharedWeights(const FreeableBuffer& shared_weights_buffer) {
+    auto& buffer_allocator = GET_NEURON_ALLOCATOR;
+    nbytes_ = shared_weights_buffer.size();
+    data_ = buffer_allocator.Allocate(nbytes_);
+    ET_CHECK_MSG(
+        data_ != nullptr,
+        "Error: Failed to allocate memory for shared weights of size %zu",
+        nbytes_);
+    std::memcpy(data_, shared_weights_buffer.data(), nbytes_);
+  }
+
+  explicit NeuronSharedWeights(FreeableBuffer&& shared_weights_buffer)
+      : NeuronSharedWeights(shared_weights_buffer) {
+    shared_weights_buffer.Free();
+  }
+
+  ~NeuronSharedWeights() {
+    if (data_ == nullptr || nbytes_ == 0) {
+      return;
+    }
+    auto& buffer_allocator = GET_NEURON_ALLOCATOR;
+    buffer_allocator.RemoveBuffer(data_);
+  }
+
+  void* data() const {
+    return data_;
+  }
+
+  size_t size() const {
+    return nbytes_;
+  }
+
+ private:
+  void* data_ = nullptr;
+  size_t nbytes_ = 0;
+};
+
 class NeuronBackend final : public ::executorch::runtime::BackendInterface {
  public:
   ::executorch::runtime::Result<::executorch::runtime::DelegateHandle*> init(
@@ -48,6 +87,10 @@ class NeuronBackend final : public ::executorch::runtime::BackendInterface {
   void destroy(::executorch::runtime::DelegateHandle* handle) const override;

   bool is_available() const override;
+
+ private:
+  mutable std::unordered_map<std::string, std::weak_ptr<NeuronSharedWeights>>
+      neuron_shared_weights_cache_;
 };

 extern const char kHighAddrKey[];
@@ -79,8 +122,7 @@ class NeuronExecuTorchDelegate {
     void* data_ptr;
     size_t size;

-    InputOutputInfo(void* ptr, size_t sz)
-        : data_ptr(ptr), size(sz) {}
+    InputOutputInfo(void* ptr, size_t sz) : data_ptr(ptr), size(sz) {}
   };

   class MemoryCache {
@@ -129,8 +171,8 @@ class NeuronExecuTorchDelegate {
     return NEURON_NO_ERROR;
   }

-  int SetSharedWeights(FreeableBuffer& buffer) {
-    neuron_shared_weights_.push_back(std::move(buffer));
+  int SetSharedWeights(std::shared_ptr<NeuronSharedWeights> sharedWeights) {
+    neuron_shared_weights_.push_back(sharedWeights);
     return NEURON_NO_ERROR;
   }

@@ -202,11 +244,12 @@ class NeuronExecuTorchDelegate {
       mPreparedInputs.push_back(InputOutputInfo{data_ptr, data_size});
     }

-    // Prepare shared weights if any as the last model input
+    // Prepare shared weights if any as the last model inputs
     if (has_shared_weights_input) {
-      FreeableBuffer& buffer = neuron_shared_weights_.at(0);
-      mPreparedInputs.push_back(
-          InputOutputInfo{const_cast<void*>(buffer.data()), buffer.size()});
+      for (const auto& shared_weights : neuron_shared_weights_) {
+        mPreparedInputs.push_back(
+            InputOutputInfo{shared_weights->data(), shared_weights->size()});
+      }
     }

     // Prepare output data
@@ -242,7 +285,8 @@ class NeuronExecuTorchDelegate {

   mutable std::unordered_set<const void*> mHasImported;

-  mutable std::vector<FreeableBuffer> neuron_shared_weights_;
+  mutable std::vector<std::shared_ptr<NeuronSharedWeights>>
+      neuron_shared_weights_;

  private:
   NeuronExecuTorchDelegate(const NeuronExecuTorchDelegate&);
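Note: NeuronSharedWeights owns one Neuron buffer for its whole lifetime: the constructor copies the FreeableBuffer into memory obtained from GET_NEURON_ALLOCATOR and the destructor returns it with RemoveBuffer. Because every delegate keeps a std::shared_ptr while the backend cache holds only a std::weak_ptr, the buffer is released exactly once, when the last delegate drops its reference. A small sketch of that ownership model; SharedBlob and the malloc/free pair are stand-ins for the real allocator.

// Ownership sketch: delegates share ownership, the cache only observes.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

class SharedBlob {
 public:
  explicit SharedBlob(size_t nbytes) : nbytes_(nbytes), data_(std::malloc(nbytes)) {}
  ~SharedBlob() { std::free(data_); }  // Freed once, by the last owner.
  void* data() const { return data_; }
  size_t size() const { return nbytes_; }

 private:
  size_t nbytes_;
  void* data_;
};

int main() {
  std::weak_ptr<SharedBlob> cache_entry;  // Observes, never owns.
  {
    auto weights = std::make_shared<SharedBlob>(1024);
    cache_entry = weights;
    std::vector<std::shared_ptr<SharedBlob>> delegates{weights, weights};
    std::cout << weights.use_count() << "\n";  // 3 owners: local + 2 delegates.
  }
  std::cout << cache_entry.expired() << "\n";  // 1: all owners gone, blob freed.
}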

examples/mediatek/executor_runner/llama_runner/LlamaConfig.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ struct LlamaModelPaths {
   std::string token_embedding_path;
   std::vector<std::string> prompt_model_paths;
   std::vector<std::string> gen_model_paths;
+  std::vector<std::string> model_package_paths;
 };

 } // namespace example

examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp

Lines changed: 26 additions & 0 deletions
@@ -21,6 +21,7 @@

 #include "LlamaConfig.h"
 #include "LlamaModelChunk.h"
+#include "Utils.h"
 #include "llm_helper/include/llm_types.h"

 #include "llm_helper/include/mask_builder.h"
@@ -42,11 +43,13 @@ inline std::vector<size_t> getIndexRange(
 LlamaModelChunk::LlamaModelChunk(
     const ModelPathMap& modelPathMap,
     const LlamaModelOptions& modelOptions,
+    const bool useSharedWeights,
     const size_t initBatchSize,
     const size_t numCache,
     const size_t numRotEmbInputs,
     const RotaryEmbeddingMasterLut* rotEmbMasterLut)
     : ModelChunk(modelPathMap, initBatchSize),
+      kIsSharedWeightsUsed(useSharedWeights),
       kMaxTokenLength(modelOptions.max_token_length),
       kCacheLength(modelOptions.cache_size),
       kMaskType(modelOptions.mask_type),
@@ -61,6 +64,29 @@ LlamaModelChunk::LlamaModelChunk(

 LlamaModelChunk::~LlamaModelChunk() {}

+std::string LlamaModelChunk::SelectMethod(
+    const std::vector<std::string>& methodNames) const {
+  const size_t curTokenSize = GetModelId();
+  for (const auto& methodName : methodNames) {
+    const auto matches = utils::extract_substr(methodName, "([0-9]+)t[0-9]+c");
+    ET_CHECK_MSG(
+        matches.size() == 2, "Invalid method name: %s", methodName.c_str());
+    // Extract the first match group as the token size
+    const size_t methodTokenSize =
+        static_cast<size_t>(std::atol(matches[1].c_str()));
+    if (curTokenSize == methodTokenSize) {
+      ET_LOG(
+          Debug,
+          "Selected method \"%s\" for token size %zu",
+          methodName.c_str(),
+          curTokenSize);
+      return methodName;
+    }
+  }
+  ET_LOG(Error, "Unable to find a suitable method; falling back to the first method.");
+  return {};
+}
+
 size_t LlamaModelChunk::GetExpectedInputCount() const {
   const size_t rotEmbInputCount = kRotEmbInputIndexes.size();
   const size_t cacheInputCount = kCacheInputIndexes.size();
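Note: SelectMethod assumes the method names inside a model package encode token and cache sizes, e.g. "128t512c", and it picks the method whose token size matches GetModelId(). utils::extract_substr is not shown in this commit, so the sketch below uses std::regex for the same "([0-9]+)t[0-9]+c" matching; the method names and the fixed token size are hypothetical.

// Sketch of token-size matching on method names like "128t512c".
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> method_names = {"1t512c", "128t512c"};
  const size_t cur_token_size = 128;  // Would come from GetModelId().
  const std::regex pattern("([0-9]+)t[0-9]+c");

  for (const auto& name : method_names) {
    std::smatch match;
    // Group 1 is the token size; the trailing number is the cache size.
    if (std::regex_search(name, match, pattern) &&
        std::stoul(match[1].str()) == cur_token_size) {
      std::cout << "Selected method: " << name << "\n";  // Prints 128t512c.
      return 0;
    }
  }
  std::cout << "No matching method\n";
  return 1;
}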

examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h

Lines changed: 12 additions & 0 deletions
@@ -44,6 +44,7 @@ class LlamaModelChunk : public ModelChunk {
   explicit LlamaModelChunk(
       const ModelPathMap& modelPathMap,
       const LlamaModelOptions& modelOptions,
+      const bool useSharedWeights,
       const size_t initBatchSize,
       const size_t numCache,
       const size_t numRotEmbInputs,
@@ -104,6 +105,17 @@ class LlamaModelChunk : public ModelChunk {
   size_t GetExpectedOutputCount() const;

  private:
+  bool AllowModelsCoexist() const override {
+    return kIsSharedWeightsUsed;
+  }
+
+  std::string SelectMethod(
+      const std::vector<std::string>& methodNames) const override;
+
+ private:
+  // Whether shared weights are used
+  bool kIsSharedWeightsUsed = false;
+
   // Input/Output Indexes
   const size_t kMaskInputIndex;
   const std::vector<size_t> kRotEmbInputIndexes;

examples/mediatek/executor_runner/llama_runner/LlamaRuntime.cpp

Lines changed: 28 additions & 7 deletions
@@ -24,9 +24,6 @@ void LlamaRuntime::Initialize(
     const LlamaModelOptions& modelOptions,
     const LlamaModelPaths& modelPaths) {
   mModelOptions = modelOptions;
-  const size_t numChunk = modelPaths.gen_model_paths.size();
-  const size_t numCache = 2 * modelOptions.num_layer / numChunk;
-  ET_CHECK_MSG(numChunk > 0, "No model to initialize");

   // Initialize rotary embedding master lookup table
   const size_t rotEmbDim = modelOptions.hidden_size / modelOptions.num_head;
@@ -37,25 +34,49 @@ void LlamaRuntime::Initialize(
       modelOptions.rot_emb_base);
   mRotEmbMasterLut->generate();

+  const bool useSharedWeights = !modelPaths.model_package_paths.empty();
+
+  ET_CHECK_MSG(
+      !useSharedWeights ||
+          modelPaths.prompt_model_paths.empty() &&
+              modelPaths.gen_model_paths.empty(),
+      "Both prompt and gen model paths should be empty when shared weights are used.");
+
+  const size_t numChunk = useSharedWeights
+      ? modelPaths.model_package_paths.size()
+      : modelPaths.gen_model_paths.size();
+  ET_CHECK_MSG(numChunk > 0, "No model to initialize");
+  const size_t numCache = 2 * modelOptions.num_layer / numChunk;
+
   constexpr size_t numRotEmbInputs = 1;
-  const bool usePromptModel = !modelPaths.prompt_model_paths.empty();
+  const bool usePromptModel = !modelPaths.prompt_model_paths.empty() ||
+      !modelPaths.model_package_paths.empty();
   const size_t initBatchSize =
       usePromptModel ? modelOptions.prompt_token_batch_size : 1;
   mTokenBatchSize = initBatchSize;

+  // Get effective prompt and gen model paths
+  const auto& [prompt_model_paths, gen_model_paths] = [&] {
+    if (useSharedWeights) {
+      return std::pair{
+          modelPaths.model_package_paths, modelPaths.model_package_paths};
+    }
+    return std::pair{modelPaths.prompt_model_paths, modelPaths.gen_model_paths};
+  }();
+
   for (size_t chunkIdx = 0; chunkIdx < numChunk; chunkIdx++) {
     ModelPathMap modelPathMap;
     auto addModelPath = [&](const auto& modelPaths, const size_t batchSize) {
       if (modelPaths.empty())
         return;
       modelPathMap[batchSize] = modelPaths[chunkIdx];
     };
-    addModelPath(
-        modelPaths.prompt_model_paths, modelOptions.prompt_token_batch_size);
-    addModelPath(modelPaths.gen_model_paths, 1);
+    addModelPath(prompt_model_paths, modelOptions.prompt_token_batch_size);
+    addModelPath(gen_model_paths, 1);
     auto llamaChunk = std::make_unique<LlamaModelChunk>(
         modelPathMap,
         modelOptions,
+        useSharedWeights,
         initBatchSize,
         numCache,
         numRotEmbInputs,
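Note: with shared weights, Initialize() treats each entry of model_package_paths as one chunk that supplies both the prompt and the gen method, so prompt_model_paths and gen_model_paths must stay empty. A hedged sketch of how a runner might fill LlamaModelPaths for this flow; the struct mirrors LlamaConfig.h and the file names are hypothetical, not produced by this commit.

// Sketch: populate model paths for the shared-weights flow.
#include <iostream>
#include <string>
#include <vector>

struct LlamaModelPaths {
  std::string token_embedding_path;
  std::vector<std::string> prompt_model_paths;
  std::vector<std::string> gen_model_paths;
  std::vector<std::string> model_package_paths;  // New in this commit.
};

LlamaModelPaths make_shared_weights_paths() {
  LlamaModelPaths paths;
  paths.token_embedding_path = "embedding.bin";
  // Each package bundles the prompt and gen methods for one chunk; Initialize()
  // then registers the same path for both batch sizes of that chunk.
  paths.model_package_paths = {"llama_chunk_0.pte", "llama_chunk_1.pte"};
  return paths;  // prompt_model_paths and gen_model_paths stay empty.
}

int main() {
  const auto paths = make_shared_weights_paths();
  std::cout << "chunks: " << paths.model_package_paths.size() << "\n";  // 2
}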
