Skip to content

Commit 24a40cf

Browse files
feat(#136): Add support for custom AI provider with URL validation (#137)
1 parent 5be27cb commit 24a40cf

File tree

10 files changed

+133
-21
lines changed

10 files changed

+133
-21
lines changed

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func NewRootCmd(out, _ io.Writer) *cobra.Command {
2929
PersistentPreRun: func(_ *cobra.Command, _ []string) { params.Log = out },
3030
}
3131
root.PersistentFlags().StringVarP(&params.Provider, "ai", "a", "none", "AI provider to use (openai, deepseek, none)")
32+
root.PersistentFlags().StringVar(&params.ProviderUrl, "ai-url", "", "Custom URL for AI provider")
3233
root.PersistentFlags().StringVarP(&params.Token, "token", "t", "", "Token for the AI provider (if required)")
3334
root.PersistentFlags().StringVar(&params.Playbook, "playbook", "", "Path to a user-defined YAML playbook for AI integration")
3435
root.PersistentFlags().BoolVar(&params.MockProject, "mock-project", false, "Use mock project")

internal/brain/brain.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,30 @@ const ollama = "ollama"
1515

1616
const mock = "mock"
1717

18+
const custom = "custom"
19+
1820
// New creates a new instance of Brain based on the provided provider and optional playbook strings.
19-
func New(provider, token, model, system string, playbook ...string) (Brain, error) {
21+
func New(provider, url, token, model, system string, playbook ...string) (Brain, error) {
2022
switch provider {
2123
case deepseek:
2224
return NewDeepSeek(token, system), nil
2325
case openai:
24-
return NewOpenAI(token, system), nil
26+
return NewOpenAIDefault(token, system)
2527
case mock:
2628
if len(playbook) == 0 {
2729
return NewMock(), nil
2830
}
2931
return NewMock(playbook[0]), nil
3032
case ollama:
3133
return NewOllama("http://localhost:11434", model, token, system), nil
34+
case custom:
35+
if url == "" {
36+
return nil, fmt.Errorf("custom provider requires a URL")
37+
}
38+
if model == "" {
39+
return nil, fmt.Errorf("custom provider requires a model")
40+
}
41+
return NewCustom(token, url, model, system)
3242
default:
3343
return nil, fmt.Errorf("unknown provider: %s", provider)
3444
}

internal/brain/brain_test.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ import (
77
"github.com/stretchr/testify/require"
88
)
99

10-
const system = "system prompt"
10+
const (
11+
system = "system prompt"
12+
address = "http://custom-provider.com/v1/chat/completions"
13+
)
1114

1215
func TestNew_WithDeepSeekProvider_ReturnsDeepSeekBrain(t *testing.T) {
1316
token := "valid_token"
1417

15-
result, err := New(deepseek, token, "gemma3", system)
18+
result, err := New(deepseek, address, token, "gemma3", system)
1619

1720
require.NoError(t, err, "Expected no error when creating DeepSeek brain")
1821
_, ok := result.(*deepSeek)
@@ -22,25 +25,50 @@ func TestNew_WithDeepSeekProvider_ReturnsDeepSeekBrain(t *testing.T) {
2225
func TestNew_WithOpenAIProvider_ReturnsOpenAIBrain(t *testing.T) {
2326
token := "valid_openai_token"
2427

25-
result, err := New(openai, token, "llama", system)
28+
result, err := New(openai, address, token, "llama", system)
2629

2730
require.NoError(t, err, "Expected no error when creating OpenAI brain")
2831
_, ok := result.(*openAI)
2932
require.True(t, ok, "Expected result to be of type OpenAI")
3033
}
3134

3235
func TestNew_MockProviderNoPlaybook_ReturnsMockInstance(t *testing.T) {
33-
result, err := New(mock, "test-token", "qwen", system)
36+
result, err := New(mock, address, "test-token", "qwen", system)
3437

3538
require.NoError(t, err)
3639
_, ok := result.(*mockBrain)
3740
require.True(t, ok, "Expected result to be of type Mock")
3841
}
3942

4043
func TestNew_UnknownProvider_ReturnsError(t *testing.T) {
41-
b, err := New("unknown", "test-token", "deepseek", system)
44+
b, err := New("unknown", address, "test-token", "deepseek", system)
4245

4346
require.Error(t, err)
4447
assert.Nil(t, b)
4548
assert.Contains(t, err.Error(), "unknown provider: unknown")
4649
}
50+
51+
func TestNew_CustomProvider_ReturnsCustomInstance(t *testing.T) {
52+
b, err := New(custom, address, "custom-token", "custom-model", system)
53+
54+
require.NoError(t, err)
55+
assert.NotNil(t, b)
56+
_, ok := b.(*openAI)
57+
require.True(t, ok, "Expected result to be OpenAI compatible instance")
58+
}
59+
60+
func TestNew_CutomProviderMissingURL_ReturnsError(t *testing.T) {
61+
b, err := New(custom, "", "custom-token", "custom-model", system)
62+
63+
require.Error(t, err)
64+
assert.Nil(t, b)
65+
require.Contains(t, err.Error(), "custom provider requires a URL")
66+
}
67+
68+
func TestNew_CustomProviderMissingModel_ReturnsError(t *testing.T) {
69+
b, err := New(custom, address, "custom-token", "", system)
70+
71+
require.Error(t, err)
72+
assert.Nil(t, b)
73+
require.Contains(t, err.Error(), "custom provider requires a model")
74+
}

internal/brain/custom.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package brain
2+
3+
func NewCustom(token, url, model, system string) (Brain, error) {
4+
return NewOpenAI(token, url, model, system)
5+
}

internal/brain/custom_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package brain
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestNewCustom(t *testing.T) {
11+
token := "test_token"
12+
url := "http://test_url.com/v1/chat/completions"
13+
model := "test_model"
14+
system := "test_system"
15+
custom, err := NewCustom(token, url, model, system)
16+
require.NoError(t, err)
17+
assert.NotNil(t, custom)
18+
expected, err := NewOpenAI(token, url, model, system)
19+
require.NoError(t, err)
20+
assert.Equal(t, expected, custom)
21+
}

internal/brain/deepseek.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ type deepseekMsg struct {
4343
func NewDeepSeek(apiKey, _ string) Brain {
4444
return &deepSeek{
4545
token: apiKey,
46-
url: "https://api.deepseek.com/chat/completions",
46+
url: "https://api.deepseek.com/v1/chat/completions",
4747
model: "deepseek-chat",
4848
}
4949
}

internal/brain/openai.go

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"io"
99
"net/http"
1010
"strings"
11+
12+
"github.com/cqfn/refrax/internal/log"
1113
)
1214

1315
// OpenAI represents a client for interacting with the OpenAI API
@@ -19,8 +21,10 @@ type openAI struct {
1921
}
2022

2123
type openaiReq struct {
22-
Model string `json:"model"`
23-
Messages []openaiMsg `json:"messages"`
24+
Model string `json:"model"`
25+
Messages []openaiMsg `json:"messages"`
26+
Stream bool `json:"stream"`
27+
Temperature *float64 `json:"temperature,omitempty"`
2428
}
2529

2630
type openaiResp struct {
@@ -36,14 +40,23 @@ type openaiMsg struct {
3640
Content string `json:"content"`
3741
}
3842

39-
// NewOpenAI creates a new OpenAI instance
40-
func NewOpenAI(apiKey, system string) Brain {
43+
// NewOpenAIDefault creates a new OpenAI instance with default settings
44+
func NewOpenAIDefault(token, system string) (Brain, error) {
45+
return NewOpenAI(token, "https://api.openai.com/v1/chat/completions", "gpt-3.5-turbo", system)
46+
}
47+
48+
// NewOpenAI creates a new OpenAI instance with the provided settings
49+
func NewOpenAI(token, url, model, system string) (Brain, error) {
50+
err := verifyUrl(url)
51+
if err != nil {
52+
return nil, err
53+
}
4154
return &openAI{
42-
token: apiKey,
43-
url: "https://api.openai.com/v1/chat/completions",
44-
model: "gpt-3.5-turbo", // Default model
55+
token: token,
56+
url: url,
57+
model: model,
4558
system: system,
46-
}
59+
}, nil
4760
}
4861

4962
// Ask sends a question to the OpenAI API
@@ -52,12 +65,16 @@ func (o *openAI) Ask(question string) (string, error) {
5265
}
5366

5467
func (o *openAI) send(system, user string) (answer string, err error) {
68+
log.Debug("sending request to '%s', model '%s', and prompt: '%s'", o.url, o.model, user)
69+
temp := float64(0.0)
5570
body := openaiReq{
5671
Model: o.model,
5772
Messages: []openaiMsg{
5873
{Role: "system", Content: system},
5974
{Role: "user", Content: strings.TrimSpace(user)},
6075
},
76+
Stream: false,
77+
Temperature: &temp,
6178
}
6279
data, err := json.Marshal(body)
6380
if err != nil {
@@ -68,7 +85,7 @@ func (o *openAI) send(system, user string) (answer string, err error) {
6885
return "", err
6986
}
7087
req.Header.Set("Content-Type", "application/json")
71-
req.Header.Set("Authorization", "Bearer "+o.token)
88+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", o.token))
7289
resp, err := http.DefaultClient.Do(req)
7390
if err != nil {
7491
return "", fmt.Errorf("API request failed: %w", err)
@@ -80,7 +97,7 @@ func (o *openAI) send(system, user string) (answer string, err error) {
8097
}()
8198
if resp.StatusCode != http.StatusOK {
8299
body, _ := io.ReadAll(resp.Body)
83-
return "", fmt.Errorf("API error: %s", body)
100+
return "", fmt.Errorf("API error (code: %d): %s", resp.StatusCode, body)
84101
}
85102
var response openaiResp
86103
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
@@ -91,3 +108,13 @@ func (o *openAI) send(system, user string) (answer string, err error) {
91108
}
92109
return response.Choices[0].Message.Content, nil
93110
}
111+
112+
func verifyUrl(url string) error {
113+
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
114+
return fmt.Errorf("invalid URL: must start with http:// or https://")
115+
}
116+
if !strings.HasSuffix(url, "/v1/chat/completions") {
117+
return fmt.Errorf("invalid URL: must end with /v1/chat/completions")
118+
}
119+
return nil
120+
}

internal/brain/openai_test.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ package brain
33
import (
44
"testing"
55

6+
"github.com/stretchr/testify/assert"
67
"github.com/stretchr/testify/require"
78
)
89

910
func TestOpenAI_Ask_PositiveCase(t *testing.T) {
1011
server := NewEchoServer(t, "gpt-3.5-turbo", "test_api_key")
1112
defer server.Close()
1213

13-
openai := NewOpenAI("test_api_key", "openai sys prompt")
14+
openai, err := NewOpenAIDefault("test_api_key", "openai sys prompt")
15+
require.NoError(t, err)
1416
openai.(*openAI).url = server.URL
1517

1618
answer, err := openai.Ask("This is a test question")
@@ -23,11 +25,28 @@ func TestOpenAI_Ask_NegativeCase(t *testing.T) {
2325
server := NewErrorServer(t)
2426
defer server.Close()
2527

26-
openai := NewOpenAI("test_api_key", "openai system prompt")
28+
openai, err := NewOpenAIDefault("test_api_key", "openai system prompt")
29+
require.NoError(t, err)
2730
openai.(*openAI).url = server.URL
2831

2932
answer, err := openai.Ask("This is a test question")
3033

3134
require.Error(t, err)
3235
require.Empty(t, answer)
3336
}
37+
38+
func TestOpenAI_WrongUrlPrefix(t *testing.T) {
39+
brain, err := NewOpenAI("token:", "ftp://invalid-url.com", "model", "system")
40+
41+
assert.Nil(t, brain)
42+
assert.Error(t, err)
43+
assert.Contains(t, err.Error(), "must start with http:// or https://")
44+
}
45+
46+
func TestOpenAI_Ask_EmptyResponse(t *testing.T) {
47+
brain, err := NewOpenAI("token", "http://invalid-url.com", "model", "system")
48+
49+
assert.Nil(t, brain)
50+
assert.Error(t, err)
51+
assert.Contains(t, err.Error(), "must end with /v1/chat/completions")
52+
}

internal/client/params.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import "io"
55
// Params holds the configuration parameters for Refrax commands.
66
type Params struct {
77
Provider string
8+
ProviderUrl string
89
Token string
910
Playbook string
1011
MockProject bool

internal/client/refrax_client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ func printStats(p Params, s ...*stats.Stats) error {
305305
}
306306

307307
func mind(p Params, token, model string, system *prompts.System, s *stats.Stats) (brain.Brain, error) {
308-
ai, err := brain.New(p.Provider, token, model, system.String(), p.Playbook)
308+
ai, err := brain.New(p.Provider, p.ProviderUrl, token, model, system.String(), p.Playbook)
309309
if p.Stats {
310310
ai = brain.NewMetricBrain(ai, s)
311311
}

0 commit comments

Comments
 (0)