Skip to content

Commit 39afdb6

Browse files
authored
feat: add configurable Gemini provider with VertexAI and GeminiAPI support (#240)
- Add support for Gemini provider configuration, including project, location, backend, and API key options - Ensure gemini.api_key is treated as sensitive and hidden in config listing - Allow Gemini client to fall back to openai.api_key if gemini.api_key is not set - Pass Gemini backend, project, and location settings to the Gemini client - Refactor Gemini client initialization to support both VertexAI and GeminiAPI backends with appropriate configuration - Add validation for required project and location fields when using the VertexAI backend - Introduce new option functions: WithProject, WithLocation, and WithBackend for flexible Gemini configuration Signed-off-by: Bo-Yi Wu <[email protected]>
1 parent 2d74b54 commit 39afdb6

File tree

5 files changed

+87
-22
lines changed

5 files changed

+87
-22
lines changed

cmd/config_list.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ var availableKeys = map[string]string{
3737
"openai.frequency_penalty": "Parameter to reduce repetition by penalizing tokens based on their frequency",
3838
"openai.presence_penalty": "Parameter to encourage topic diversity by penalizing previously used tokens",
3939
"prompt.folder": "Directory path for custom prompt templates",
40+
"gemini.project": "VertexAI project for Gemini provider",
41+
"gemini.location": "VertexAI location for Gemini provider",
42+
"gemini.backend": "Gemini backend (BackendGeminiAPI or BackendVertexAI)",
43+
"gemini.api_key": "API key for Gemini provider",
4044
}
4145

4246
// configListCmd represents the command to list the configuration values.
@@ -65,7 +69,7 @@ var configListCmd = &cobra.Command{
6569
// Add the key and value to the table
6670
for _, v := range keys {
6771
// Hide the api key
68-
if v == "openai.api_key" {
72+
if v == "openai.api_key" || v == "gemini.api_key" {
6973
tbl.AddRow(v, "****************")
7074
continue
7175
}

cmd/config_set.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ func init() {
3333
configSetCmd.Flags().StringP("headers", "", "", availableKeys["openai.headers"])
3434
configSetCmd.Flags().StringP("api_version", "", "", availableKeys["openai.api_version"])
3535
configSetCmd.Flags().StringP("prompt_folder", "", "", availableKeys["prompt.folder"])
36+
// Gemini flags
37+
configSetCmd.Flags().String("gemini.project", "", availableKeys["gemini.project"])
38+
configSetCmd.Flags().String("gemini.location", "", availableKeys["gemini.location"])
39+
configSetCmd.Flags().String("gemini.backend", "BackendGeminiAPI", availableKeys["gemini.backend"])
40+
configSetCmd.Flags().String("gemini.api_key", "", availableKeys["gemini.api_key"])
41+
3642
_ = viper.BindPFlag("openai.base_url", configSetCmd.Flags().Lookup("base_url"))
3743
_ = viper.BindPFlag("openai.org_id", configSetCmd.Flags().Lookup("org_id"))
3844
_ = viper.BindPFlag("openai.api_key", configSetCmd.Flags().Lookup("api_key"))
@@ -52,6 +58,10 @@ func init() {
5258
_ = viper.BindPFlag("openai.headers", configSetCmd.Flags().Lookup("headers"))
5359
_ = viper.BindPFlag("openai.api_version", configSetCmd.Flags().Lookup("api_version"))
5460
_ = viper.BindPFlag("prompt.folder", configSetCmd.Flags().Lookup("prompt_folder"))
61+
_ = viper.BindPFlag("gemini.project", configSetCmd.Flags().Lookup("gemini.project"))
62+
_ = viper.BindPFlag("gemini.location", configSetCmd.Flags().Lookup("gemini.location"))
63+
_ = viper.BindPFlag("gemini.backend", configSetCmd.Flags().Lookup("gemini.backend"))
64+
_ = viper.BindPFlag("gemini.api_key", configSetCmd.Flags().Lookup("gemini.api_key"))
5565
}
5666

5767
// configSetCmd updates the config value.

cmd/provider.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,20 @@ func NewOpenAI() (*openai.Client, error) {
3535

3636
// NewGemini returns a new Gemini client
3737
func NewGemini(ctx context.Context) (*gemini.Client, error) {
38+
apiKey := viper.GetString("gemini.api_key")
39+
if apiKey == "" {
40+
apiKey = viper.GetString("openai.api_key")
41+
}
3842
return gemini.New(
3943
ctx,
40-
gemini.WithToken(viper.GetString("openai.api_key")),
44+
gemini.WithToken(apiKey),
4145
gemini.WithModel(viper.GetString("openai.model")),
4246
gemini.WithMaxTokens(viper.GetInt32("openai.max_tokens")),
4347
gemini.WithTemperature(float32(viper.GetFloat64("openai.temperature"))),
4448
gemini.WithTopP(float32(viper.GetFloat64("openai.top_p"))),
49+
gemini.WithBackend(viper.GetString("gemini.backend")),
50+
gemini.WithProject(viper.GetString("gemini.project")),
51+
gemini.WithLocation(viper.GetString("gemini.location")),
4552
)
4653
}
4754

provider/gemini/gemini.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,26 @@ func New(ctx context.Context, opts ...Option) (c *Client, err error) {
159159
},
160160
}
161161

162-
client, err := genai.NewClient(ctx, &genai.ClientConfig{
163-
APIKey: cfg.token,
164-
HTTPClient: httpClient,
165-
Backend: genai.BackendGeminiAPI,
166-
})
162+
var clientConfig *genai.ClientConfig
163+
switch cfg.backend {
164+
case genai.BackendVertexAI:
165+
clientConfig = &genai.ClientConfig{
166+
HTTPClient: httpClient,
167+
Backend: cfg.backend,
168+
Project: cfg.project,
169+
Location: cfg.location,
170+
}
171+
case genai.BackendGeminiAPI, genai.BackendUnspecified:
172+
fallthrough
173+
default:
174+
cfg.backend = genai.BackendGeminiAPI
175+
clientConfig = &genai.ClientConfig{
176+
APIKey: cfg.token,
177+
HTTPClient: httpClient,
178+
Backend: cfg.backend,
179+
}
180+
}
181+
client, err := genai.NewClient(ctx, clientConfig)
167182
if err != nil {
168183
return nil, err
169184
}

provider/gemini/options.go

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package gemini
22

33
import (
44
"errors"
5+
6+
"google.golang.org/genai"
57
)
68

79
var (
8-
errorsMissingToken = errors.New("missing gemini api key")
9-
errorsMissingModel = errors.New("missing model")
10+
errorsMissingToken = errors.New("missing gemini api key")
11+
errorsMissingTokenOrProject = errors.New("missing token or project")
12+
errorsMissingModel = errors.New("missing model")
1013
)
1114

1215
const (
@@ -79,45 +82,71 @@ func WithTopP(val float32) Option {
7982
})
8083
}
8184

82-
// config is a struct that stores configuration options for the instrumentation.
8385
type config struct {
8486
token string
8587
model string
8688
maxTokens int32
8789
temperature float32
8890
topP float32
91+
project string
92+
location string
93+
backend genai.Backend
8994
}
9095

91-
// valid checks whether a config object is valid, returning an error if it is not.
9296
func (cfg *config) valid() error {
93-
// Check that the token is not empty.
94-
if cfg.token == "" {
95-
return errorsMissingToken
97+
if cfg.backend == genai.BackendVertexAI {
98+
if cfg.project == "" || cfg.location == "" {
99+
return errorsMissingTokenOrProject
100+
}
101+
} else {
102+
if cfg.token == "" {
103+
return errorsMissingToken
104+
}
96105
}
97-
98106
if cfg.model == "" {
99107
return errorsMissingModel
100108
}
101-
102-
// If all checks pass, return nil (no error).
103109
return nil
104110
}
105111

112+
func WithLocation(val string) Option {
113+
return optionFunc(func(c *config) {
114+
c.location = val
115+
})
116+
}
117+
118+
func WithBackend(val string) Option {
119+
return optionFunc(func(c *config) {
120+
switch val {
121+
case "BackendVertexAI":
122+
c.backend = genai.BackendVertexAI
123+
case "BackendGeminiAPI":
124+
c.backend = genai.BackendGeminiAPI
125+
case "BackendUnspecified":
126+
fallthrough
127+
default:
128+
c.backend = genai.BackendGeminiAPI
129+
}
130+
})
131+
}
132+
133+
func WithProject(val string) Option {
134+
return optionFunc(func(c *config) {
135+
c.project = val
136+
})
137+
}
138+
106139
// newConfig creates a new config object with default values, and applies the given options.
107140
func newConfig(opts ...Option) *config {
108-
// Create a new config object with default values.
109141
c := &config{
110142
model: defaultModel,
111143
maxTokens: defaultMaxTokens,
112144
temperature: defaultTemperature,
113145
topP: defaultTopP,
146+
backend: genai.BackendGeminiAPI,
114147
}
115-
116-
// Apply each of the given options to the config object.
117148
for _, opt := range opts {
118149
opt.apply(c)
119150
}
120-
121-
// Return the resulting config object.
122151
return c
123152
}

0 commit comments

Comments
 (0)