
Commit a70e942

feat: support Grok-4 by adding x.ai API category (xai)
1 parent c37f154 commit a70e942

2 files changed: +38 -27 lines

lua/gp/config.lua

Lines changed: 14 additions & 0 deletions
@@ -72,6 +72,11 @@ local config = {
       endpoint = "https://api.anthropic.com/v1/messages",
       secret = os.getenv("ANTHROPIC_API_KEY"),
     },
+    xai = {
+      disable = true,
+      endpoint = "https://api.x.ai/v1/chat/completions",
+      secret = os.getenv("XAI_API_KEY"),
+    },
   },
 
   -- prefix for all commands
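
The xai provider ships disabled (disable = true), so it has to be switched on from user config. A minimal sketch of what that could look like, assuming the usual require("gp").setup() override pattern; the override table below is illustrative, not part of this commit:

    -- user Neovim config; assumes XAI_API_KEY is exported in the environment
    require("gp").setup({
      providers = {
        xai = {
          disable = false,
          endpoint = "https://api.x.ai/v1/chat/completions",
          secret = os.getenv("XAI_API_KEY"),
        },
      },
    })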
@@ -310,6 +315,15 @@ local config = {
       model = { model = "claude-3-5-haiku-latest", temperature = 0.8, top_p = 1 },
       system_prompt = require("gp.defaults").code_system_prompt,
     },
+    {
+      provider = "xai",
+      name = "Grok-4",
+      chat = false,
+      command = true,
+      -- string with model name or table with model name and parameters
+      model = { model = "grok-4-latest", temperature = 0 },
+      system_prompt = require("gp.defaults").code_system_prompt,
+    },
     {
       provider = "ollama",
       name = "CodeOllamaLlama3.1-8B",
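
The stock Grok-4 agent is command-only (chat = false, command = true). A chat-capable Grok agent could be declared from user config in the same agent format; a hedged sketch, where the agent name, temperature, and the chat_system_prompt default are picked purely for illustration:

    require("gp").setup({
      agents = {
        {
          provider = "xai",
          name = "ChatGrok4",
          chat = true,
          command = false,
          -- string with model name or table with model name and parameters
          model = { model = "grok-4-latest", temperature = 0.7 },
          system_prompt = require("gp.defaults").chat_system_prompt,
        },
      },
    })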

lua/gp/dispatcher.lua

Lines changed: 24 additions & 27 deletions
@@ -169,38 +169,27 @@ D.prepare_payload = function(messages, model, provider)
     return payload
   end
 
-  if provider == "ollama" then
+  if provider == "xai" then
+    local system = ""
+    local i = 1
+    while i < #messages do
+      if messages[i].role == "system" then
+        system = system .. messages[i].content .. "\n"
+        table.remove(messages, i)
+      else
+        i = i + 1
+      end
+    end
+
     local payload = {
       model = model.model,
       stream = true,
       messages = messages,
+      system = system,
+      max_tokens = model.max_tokens or 4096,
+      temperature = math.max(0, math.min(2, model.temperature or 1)),
+      top_p = math.max(0, math.min(1, model.top_p or 1)),
     }
-
-    if model.think ~= nil then
-      payload.think = model.think
-    end
-
-    local options = {}
-    if model.temperature then
-      options.temperature = math.max(0, math.min(2, model.temperature))
-    end
-    if model.top_p then
-      options.top_p = math.max(0, math.min(1, model.top_p))
-    end
-    if model.min_p then
-      options.min_p = math.max(0, math.min(1, model.min_p))
-    end
-    if model.num_ctx then
-      options.num_ctx = model.num_ctx
-    end
-    if model.top_k then
-      options.top_k = model.top_k
-    end
-
-    if next(options) then
-      payload.options = options
-    end
-
     return payload
   end
 
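
To make the new branch concrete: system messages are folded into a single system string and removed from messages, then a fixed-shape payload is built (as written, the i < #messages condition never inspects the last entry, so a trailing system message would pass through untouched). A standalone walk-through with a made-up two-message conversation and the Grok-4 defaults from config.lua:

    -- plain Lua illustration, not plugin code
    local messages = {
      { role = "system", content = "You are an expert coder." },
      { role = "user", content = "Refactor this function." },
    }

    -- same folding step as the new xai branch above
    local system = ""
    local i = 1
    while i < #messages do
      if messages[i].role == "system" then
        system = system .. messages[i].content .. "\n"
        table.remove(messages, i)
      else
        i = i + 1
      end
    end

    assert(system == "You are an expert coder.\n")
    assert(#messages == 1 and messages[1].role == "user")
    -- with model = { model = "grok-4-latest", temperature = 0 }, the branch then returns roughly:
    -- { model = "grok-4-latest", stream = true, messages = messages, system = system,
    --   max_tokens = 4096, temperature = 0, top_p = 1 }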

@@ -454,6 +443,14 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
     endpoint = render.template_replace(endpoint, "{{model}}", payload.model)
   elseif provider == "ollama" then
     headers = {}
+  elseif provider == "xai" then
+    -- currently xai only uses a bearer token for authentication.
+    -- since that may not stay the case, it gets its own branch
+    -- instead of falling through to the openai-compatible default.
+    headers = {
+      "-H",
+      "Authorization: Bearer " .. bearer,
+    }
   else -- default to openai compatible headers
     headers = {
       "-H",

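As the comment notes, the xai branch currently builds the same kind of bearer Authorization header as the OpenAI-compatible default below it (whose body is truncated in this hunk); the separate elseif only leaves room for x.ai authentication to diverge later. Since the header hinges on the secret resolving to a real key, a tiny standalone check (plain Lua, illustrative only; in the plugin the bearer value comes from the resolved provider secret):

    local bearer = os.getenv("XAI_API_KEY")
    if not bearer or bearer == "" then
      error("XAI_API_KEY is not set; the xai provider cannot build its Authorization header")
    end
    local headers = { "-H", "Authorization: Bearer " .. bearer }
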
0 commit comments
