|
| 1 | +local Job = require('plenary.job') |
| 2 | +local log = require("codecompanion.utils.log") |
| 3 | +local openai = require("codecompanion.adapters.openai") |
| 4 | +local utils = require("codecompanion.utils.adapters") |
| 5 | + |
| 6 | +---@alias GhToken string|nil |
| 7 | +local _gh_token |
| 8 | + |
| 9 | +local function get_github_token() |
| 10 | + local token |
| 11 | + local job = Job:new({ |
| 12 | + command = 'gh', |
| 13 | + args = { 'auth', 'token', '-h', 'github.com' }, |
| 14 | + on_exit = function(j, return_val) |
| 15 | + if return_val == 0 then |
| 16 | + token = j:result()[1] |
| 17 | + end |
| 18 | + end, |
| 19 | + }) |
| 20 | + |
| 21 | + job:sync() |
| 22 | + return token |
| 23 | +end |
| 24 | + |
| 25 | +---Authorize the GitHub OAuth token |
| 26 | +---@return GhToken |
| 27 | +local function authorize_token() |
| 28 | + if _gh_token then |
| 29 | + log:debug("Reusing gh cli token") |
| 30 | + return _gh_token |
| 31 | + end |
| 32 | + |
| 33 | + log:debug("Getting gh cli token") |
| 34 | + |
| 35 | + _gh_token = get_github_token() |
| 36 | + |
| 37 | + return _gh_token |
| 38 | +end |
| 39 | + |
| 40 | +---@class GitHubModels.Adapter: CodeCompanion.Adapter |
| 41 | +return { |
| 42 | + name = "githubmodels", |
| 43 | + formatted_name = "GitHub Models", |
| 44 | + roles = { |
| 45 | + llm = "assistant", |
| 46 | + user = "user", |
| 47 | + }, |
| 48 | + opts = { |
| 49 | + stream = true, |
| 50 | + }, |
| 51 | + features = { |
| 52 | + text = true, |
| 53 | + tokens = true, |
| 54 | + vision = false, |
| 55 | + }, |
| 56 | + url = "https://models.inference.ai.azure.com/chat/completions", |
| 57 | + env = { |
| 58 | + ---@return string|nil |
| 59 | + api_key = function() |
| 60 | + return authorize_token() |
| 61 | + end, |
| 62 | + }, |
| 63 | + headers = { |
| 64 | + Authorization = "Bearer ${api_key}", |
| 65 | + ["Content-Type"] = "application/json", |
| 66 | + -- Idea below taken from : https://github.com/github/gh-models/blob/d3b8d3e1d4c5a412e9af09a43a42eb365dac5751/internal/azuremodels/azure_client.go#L69 |
| 67 | + -- Azure would like us to send specific user agents to help distinguish |
| 68 | + -- traffic from known sources and other web requests |
| 69 | + -- send both to accommodate various Azure consumers |
| 70 | + ["x-ms-useragent"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, |
| 71 | + ["x-ms-user-agent"] = "Neovim/" .. vim.version().major .. "." .. vim.version().minor .. "." .. vim.version().patch, |
| 72 | + }, |
| 73 | + handlers = { |
| 74 | + ---Check for a token before starting the request |
| 75 | + ---@param self CodeCompanion.Adapter |
| 76 | + ---@return boolean |
| 77 | + setup = function(self) |
| 78 | + local model = self.schema.model.default |
| 79 | + local model_opts = self.schema.model.choices[model] |
| 80 | + if model_opts and model_opts.opts then |
| 81 | + self.opts = vim.tbl_deep_extend("force", self.opts, model_opts.opts) |
| 82 | + end |
| 83 | + |
| 84 | + if self.opts and self.opts.stream then |
| 85 | + self.parameters.stream = true |
| 86 | + end |
| 87 | + |
| 88 | + _gh_token = authorize_token() |
| 89 | + if not _gh_token then |
| 90 | + log:error("GitHub Models Adapter: Could not authorize your GitHub token") |
| 91 | + return false |
| 92 | + end |
| 93 | + |
| 94 | + return true |
| 95 | + end, |
| 96 | + |
| 97 | + --- Use the OpenAI adapter for the bulk of the work |
| 98 | + form_parameters = function(self, params, messages) |
| 99 | + return openai.handlers.form_parameters(self, params, messages) |
| 100 | + end, |
| 101 | + form_messages = function(self, messages) |
| 102 | + return openai.handlers.form_messages(self, messages) |
| 103 | + end, |
| 104 | + tokens = function(self, data) |
| 105 | + if data and data ~= "" then |
| 106 | + local data_mod = utils.clean_streamed_data(data) |
| 107 | + local ok, json = pcall(vim.json.decode, data_mod, { luanil = { object = true } }) |
| 108 | + |
| 109 | + if ok then |
| 110 | + if json.usage then |
| 111 | + local total_tokens = json.usage.total_tokens or 0 |
| 112 | + local completion_tokens = json.usage.completion_tokens or 0 |
| 113 | + local prompt_tokens = json.usage.prompt_tokens or 0 |
| 114 | + local tokens = total_tokens > 0 and total_tokens or completion_tokens + prompt_tokens |
| 115 | + log:trace("Tokens: %s", tokens) |
| 116 | + return tokens |
| 117 | + end |
| 118 | + end |
| 119 | + end |
| 120 | + end, |
| 121 | + chat_output = function(self, data) |
| 122 | + return openai.handlers.chat_output(self, data) |
| 123 | + end, |
| 124 | + inline_output = function(self, data, context) |
| 125 | + return openai.handlers.inline_output(self, data, context) |
| 126 | + end, |
| 127 | + on_exit = function(self, data) |
| 128 | + return openai.handlers.on_exit(self, data) |
| 129 | + end, |
| 130 | + }, |
| 131 | + schema = { |
| 132 | + model = { |
| 133 | + order = 1, |
| 134 | + mapping = "parameters", |
| 135 | + type = "enum", |
| 136 | + desc = |
| 137 | + "ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.", |
| 138 | + ---@type string|fun(): string |
| 139 | + default = "gpt-4o", |
| 140 | + choices = { |
| 141 | + ["o3-mini"] = { opts = { can_reason = true } }, |
| 142 | + ["o1"] = { opts = { can_reason = true } }, |
| 143 | + ["o1-mini"] = { opts = { can_reason = true } }, |
| 144 | + "claude-3.5-sonnet", |
| 145 | + "gpt-4o", |
| 146 | + "gpt-4o-mini", |
| 147 | + "DeepSeek-R1", |
| 148 | + "Codestral-2501", |
| 149 | + }, |
| 150 | + }, |
| 151 | + reasoning_effort = { |
| 152 | + order = 2, |
| 153 | + mapping = "parameters", |
| 154 | + type = "string", |
| 155 | + optional = true, |
| 156 | + condition = function(schema) |
| 157 | + local model = schema.model.default |
| 158 | + if type(model) == "function" then |
| 159 | + model = model() |
| 160 | + end |
| 161 | + if schema.model.choices[model] and schema.model.choices[model].opts then |
| 162 | + return schema.model.choices[model].opts.can_reason |
| 163 | + end |
| 164 | + end, |
| 165 | + default = "medium", |
| 166 | + desc = |
| 167 | + "Constrains effort on reasoning for reasoning models. Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.", |
| 168 | + choices = { |
| 169 | + "high", |
| 170 | + "medium", |
| 171 | + "low", |
| 172 | + }, |
| 173 | + }, |
| 174 | + temperature = { |
| 175 | + order = 3, |
| 176 | + mapping = "parameters", |
| 177 | + type = "number", |
| 178 | + default = 0, |
| 179 | + condition = function(schema) |
| 180 | + local model = schema.model.default |
| 181 | + if type(model) == "function" then |
| 182 | + model = model() |
| 183 | + end |
| 184 | + return not vim.startswith(model, "o1") |
| 185 | + end, |
| 186 | + desc = |
| 187 | + "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.", |
| 188 | + }, |
| 189 | + max_tokens = { |
| 190 | + order = 4, |
| 191 | + mapping = "parameters", |
| 192 | + type = "integer", |
| 193 | + default = 4096, |
| 194 | + desc = |
| 195 | + "The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length.", |
| 196 | + }, |
| 197 | + top_p = { |
| 198 | + order = 5, |
| 199 | + mapping = "parameters", |
| 200 | + type = "number", |
| 201 | + default = 1, |
| 202 | + condition = function(schema) |
| 203 | + local model = schema.model.default |
| 204 | + if type(model) == "function" then |
| 205 | + model = model() |
| 206 | + end |
| 207 | + return not vim.startswith(model, "o1") |
| 208 | + end, |
| 209 | + desc = |
| 210 | + "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.", |
| 211 | + }, |
| 212 | + n = { |
| 213 | + order = 6, |
| 214 | + mapping = "parameters", |
| 215 | + type = "number", |
| 216 | + default = 1, |
| 217 | + condition = function(schema) |
| 218 | + local model = schema.model.default |
| 219 | + if type(model) == "function" then |
| 220 | + model = model() |
| 221 | + end |
| 222 | + return not vim.startswith(model, "o1") |
| 223 | + end, |
| 224 | + desc = "How many chat completions to generate for each prompt.", |
| 225 | + }, |
| 226 | + }, |
| 227 | +} |
0 commit comments