Skip to content

Commit 1896c73

Browse files
committed
feat: expose ModelRegistry to the plugin, ref #13
1 parent 5567a93 commit 1896c73

File tree

6 files changed

+22
-21
lines changed

6 files changed

+22
-21
lines changed

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -236,9 +236,14 @@ class LlamaModel final : public Model {
236236
public:
237237

238238
LlamaModel(const std::string& gguf, std::span<std::string> loras, std::vector<llama::ControlVector::LoadInfo>& ctrlVectors, llama::ModelLoadProgressCb pcb, llama::Model::Params params)
239-
: m_model(std::make_shared<llama::Model>(llama::Model(gguf.c_str(), loras, astl::move(pcb), astl::move(params))))
239+
: m_model(std::make_shared<llama::Model>(llama::ModelRegistry::getInstance().loadModel(gguf.c_str(), astl::move(pcb), params), astl::move(params)))
240240
, m_ctrlVectors(astl::move(ctrlVectors))
241-
{}
241+
{
242+
for(auto& loraPath: loras) {
243+
auto lora = llama::ModelRegistry::getInstance().loadLora(m_model.get(), loraPath);
244+
m_model->addLora(lora);;
245+
}
246+
}
242247

243248
virtual std::unique_ptr<Instance> createInstance(std::string_view type, Dict params) override {
244249
ac::llama::ControlVector ctrlVector(*m_model, m_ctrlVectors);

code/ac/llama/Model.cpp

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -34,19 +34,10 @@ llama_model_params llamaFromModelParams(const Model::Params& params, ModelLoadPr
3434
} // namespace
3535

3636

37-
Model::Model(const char* pathToGguf, std::span<std::string> loras, ModelLoadProgressCb loadProgressCb, Params params)
38-
: m_params(astl::move(params))
39-
{
40-
m_lmodel = ModelRegistry::getInstance().loadModel(pathToGguf, std::move(loadProgressCb), m_params);
41-
if (!m_lmodel) {
42-
throw std::runtime_error("Failed to load model");
43-
}
44-
45-
for(auto& loraPath: loras) {
46-
auto lora = ModelRegistry::getInstance().loadLora(this, loraPath);
47-
m_loras.push_back(lora);
48-
}
49-
}
37+
Model::Model(std::shared_ptr<llama_model> lmodel, Params params)
38+
: m_lmodel(std::move(lmodel))
39+
, m_params(astl::move(params))
40+
{}
5041

5142
Model::~Model() = default;
5243

code/ac/llama/Model.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ class AC_LLAMA_EXPORT Model {
2626
bool prefixInputsWithBos = false; // add bos token to interactive inputs (#13)
2727
};
2828

29-
Model(const char* pathToGguf, std::span<std::string> loras, ModelLoadProgressCb loadProgressCb, Params params);
29+
Model(std::shared_ptr<llama_model> model, Params params);
3030
~Model();
3131

3232
const Params& params() const noexcept { return m_params; }
@@ -57,7 +57,7 @@ class AC_LLAMA_EXPORT Model {
5757
Vocab m_vocab{*this};
5858
};
5959

60-
class ModelRegistry {
60+
class AC_LLAMA_EXPORT ModelRegistry {
6161
public:
6262
static ModelRegistry& getInstance() {
6363
static ModelRegistry instance;

example/e-basic.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,8 @@ int main() try {
4545
}
4646
return true;
4747
};
48-
ac::llama::Model model(modelGguf.c_str(), {}, modelLoadProgressCallback, modelParams);
48+
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(modelGguf, modelLoadProgressCallback, modelParams);
49+
ac::llama::Model model(lmodel, modelParams);
4950

5051

5152
// create inference instance

example/e-gui.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ class UModel {
6161
class State {
6262
public:
6363
State(const std::string& ggufPath, const ac::llama::Model::Params& modelParams)
64-
: m_model(ggufPath.c_str(), {}, printModelLoadProgress, modelParams)
64+
: m_model(ac::llama::ModelRegistry::getInstance().loadModel(ggufPath.c_str(), printModelLoadProgress, modelParams), modelParams)
6565
{}
6666

6767
class Instance {

test/t-integration.cpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,9 @@ GlobalFixture globalFixture;
2222
const char* Model_117m_q6_k = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf";
2323

2424
TEST_CASE("vocab only") {
25-
ac::llama::Model model(Model_117m_q6_k, {}, {}, { .vocabOnly = true });
25+
ac::llama::Model::Params iParams = { .vocabOnly = true };
26+
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_117m_q6_k, {}, iParams);
27+
ac::llama::Model model(lmodel, iParams);
2628
CHECK(!!model.lmodel());
2729

2830
auto& params = model.params();
@@ -40,7 +42,9 @@ TEST_CASE("vocab only") {
4042
}
4143

4244
TEST_CASE("inference") {
43-
ac::llama::Model model(Model_117m_q6_k, {}, {}, {});
45+
ac::llama::Model::Params iParams = {};
46+
auto lmodel = ac::llama::ModelRegistry::getInstance().loadModel(Model_117m_q6_k, {}, iParams);
47+
ac::llama::Model model(lmodel, iParams);
4448
CHECK(!!model.lmodel());
4549

4650
auto& params = model.params();

0 commit comments

Comments
 (0)