diff --git a/docs/source/gemini.png b/docs/source/gemini.png
index f99f86d9..e450744e 100644
Binary files a/docs/source/gemini.png and b/docs/source/gemini.png differ
diff --git a/docs/source/magics.rst b/docs/source/magics.rst
index d8e64274..d596a71b 100644
--- a/docs/source/magics.rst
+++ b/docs/source/magics.rst
@@ -12,15 +12,29 @@ Here are the magics available in xeus-cpp.
 
 %%xassist
 ========================
 
-Leverage the large language models to assist in your development process. Currently supported models are Gemini - gemini-1.5-flash, OpenAI - gpt-3.5-turbo-16k.
+Leverage large language models to assist in your development process. Currently supported providers are Gemini, OpenAI, and Ollama.
 
-- Save the api key
+- Save the API key (for OpenAI and Gemini)
 
 .. code::
 
     %%xassist model --save-key
     key
+
+- Set the response URL (for Ollama)
+
+.. code::
+
+    %%xassist model --set-url
+    url
+
+- Save the model name (required for all providers)
+
+.. code::
+
+    %%xassist model --save-model
+    model-name
+
 - Use the model
 
 .. code::
@@ -33,9 +47,10 @@ Leverage the large language models to assist in your development process. Curren
 
 .. code::
 
     %%xassist model --refresh
+
-- Example
+- Examples
 
 .. image:: gemini.png
 
-A new prompt is sent to the model everytime and the functionality to use previous context will be added soon.
\ No newline at end of file
+.. image:: ollama.png
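Note: putting the new commands together, a typical Ollama setup in a notebook might look as follows. The endpoint and model name are illustrative assumptions; any model already pulled into a local Ollama instance works, and http://localhost:11434/api/chat is Ollama's default chat endpoint.

.. code::

    %%xassist ollama --set-url
    http://localhost:11434/api/chat

.. code::

    %%xassist ollama --save-model
    llama3

.. code::

    %%xassist ollama
    Write a C++ function that adds two numbers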
"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=" - + key; + const std::string model = xcpp::model_manager::load_model("gemini"); + + if (model.empty()) + { + std::cerr << "Model not found." << std::endl; + return ""; + } + + const std::string url = "https://generativelanguage.googleapis.com/v1beta/models/" + model + + ":generateContent?key=" + key; const std::string post_data = R"({"contents": [ )" + chat_message + R"(]})"; std::string response = curl_helper.perform_request(url, post_data); @@ -231,13 +313,64 @@ namespace xcpp return j["candidates"][0]["content"]["parts"][0]["text"]; } + std::string ollama(const std::string& cell) + { + curl_helper curl_helper; + const std::string url = xcpp::url_manager::load_url("ollama"); + const std::string chat_message = xcpp::chat_history::chat("ollama", "user", cell); + const std::string model = xcpp::model_manager::load_model("ollama"); + + if (model.empty()) + { + std::cerr << "Model not found." << std::endl; + return ""; + } + + if (url.empty()) + { + std::cerr << "URL not found." << std::endl; + return ""; + } + + const std::string post_data = R"({ + "model": ")" + model + + R"(", + "messages": [)" + chat_message + + R"(], + "stream": false + })"; + + std::string response = curl_helper.perform_request(url, post_data); + + json j = json::parse(response); + + if (j.find("error") != j.end()) + { + std::cerr << "Error: " << j["error"]["message"] << std::endl; + return ""; + } + + const std::string chat = xcpp::chat_history::chat("ollama", "assistant", j["message"]["content"]); + + return j["message"]["content"]; + } + std::string openai(const std::string& cell, const std::string& key) { curl_helper curl_helper; const std::string url = "https://api.openai.com/v1/chat/completions"; const std::string chat_message = xcpp::chat_history::chat("openai", "user", cell); + const std::string model = xcpp::model_manager::load_model("openai"); + + if (model.empty()) + { + std::cerr << "Model not found." << std::endl; + return ""; + } + const std::string post_data = R"({ - "model": "gpt-3.5-turbo-16k", + "model": [)" + model + + R"(], "messages": [)" + chat_message + R"(], "temperature": 0.7 @@ -273,7 +406,7 @@ namespace xcpp std::istream_iterator() ); - std::vector models = {"gemini", "openai"}; + std::vector models = {"gemini", "openai", "ollama"}; std::string model = tokens[1]; if (std::find(models.begin(), models.end(), model) == models.end()) @@ -295,13 +428,29 @@ namespace xcpp xcpp::chat_history::refresh(model); return; } + + if (tokens[2] == "--save-model") + { + xcpp::model_manager::save_model(model, cell); + return; + } + + if (tokens[2] == "--set-url" && model == "ollama") + { + xcpp::url_manager::save_url(model, cell); + return; + } } - std::string key = xcpp::api_key_manager::load_api_key(model); - if (key.empty()) + std::string key; + if (model != "ollama") { - std::cerr << "API key for model " << model << " is not available." << std::endl; - return; + key = xcpp::api_key_manager::load_api_key(model); + if (key.empty()) + { + std::cerr << "API key for model " << model << " is not available." 
<< std::endl; + return; + } } std::string response; @@ -313,6 +462,10 @@ namespace xcpp { response = openai(cell, key); } + else if (model == "ollama") + { + response = ollama(cell); + } std::cout << response; } diff --git a/test/test_interpreter.cpp b/test/test_interpreter.cpp index eea3c28a..7c4b5cf7 100644 --- a/test/test_interpreter.cpp +++ b/test/test_interpreter.cpp @@ -962,4 +962,39 @@ TEST_SUITE("xassist"){ std::remove("openai_api_key.txt"); } + TEST_CASE("ollama"){ + xcpp::xassist assist; + std::string line = "%%xassist ollama --set-url"; + std::string cell = "1234"; + + assist(line, cell); + + std::ifstream infile("ollama_url.txt"); + std::string content; + std::getline(infile, content); + + REQUIRE(content == "1234"); + infile.close(); + + line = "%%xassist ollama --save-model"; + cell = "1234"; + + assist(line, cell); + + std::ifstream infile_model("ollama_model.txt"); + std::string content_model; + std::getline(infile_model, content_model); + + REQUIRE(content_model == "1234"); + infile_model.close(); + + StreamRedirectRAII redirect(std::cerr); + + assist("%%xassist openai", "hello"); + + REQUIRE(!redirect.getCaptured().empty()); + + std::remove("openai_api_key.txt"); + } + } \ No newline at end of file