Skip to content

Commit 3ea1edf

Browse files
authored
feat(api): switch func calling to tools (#157)
- Add imports for "fmt" and "log" in `commit.go` - Refactor message output to use a variable in `commit.go` - Update error handling for completion response to check for both errors and the number of choices in `commit.go` - Simplify the handling of summary prefix by removing conditional checks and replacing them with a single error check and assignment in `commit.go` - Change `CreateFunctionCall` to accept a single `FunctionDefinition` instead of a variadic slice in `openai.go` - Remove hardcoded system message content in `openai.go` - Replace `Functions` field with `Tools` in `CreateFunctionCall` to use the new `Tool` struct in `openai.go` - Remove specific function call strings from the allow list in `openai.go` and replace with a single entry Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
1 parent 3059df1 commit 3ea1edf

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

cmd/commit.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package cmd
22

33
import (
4+
"fmt"
45
"html"
6+
"log"
57
"os"
68
"path"
79
"strconv"
@@ -203,25 +205,31 @@ var commitCmd = &cobra.Command{
203205
if err != nil {
204206
return err
205207
}
206-
color.Cyan("We are trying to get conventional commit prefix")
208+
message := "We are trying to get conventional commit prefix"
207209
summaryPrix := ""
208210
if client.AllowFuncCall() {
211+
color.Cyan(message + " (Tools)")
209212
resp, err := client.CreateFunctionCall(cmd.Context(), out, openai.SummaryPrefixFunc)
210-
if err != nil {
213+
if err != nil || len(resp.Choices) != 1 {
214+
log.Printf("Completion error: err:%v len(choices):%v\n", err,
215+
len(resp.Choices))
211216
return err
212217
}
213-
if len(resp.Choices) > 0 {
214-
summaryPrix = strings.TrimSpace(resp.Choices[0].Message.Content)
215-
if resp.Choices[0].Message.FunctionCall != nil {
216-
args := openai.GetSummaryPrefixArgs(resp.Choices[0].Message.FunctionCall.Arguments)
217-
summaryPrix = args.Prefix
218-
}
218+
219+
msg := resp.Choices[0].Message
220+
if len(msg.ToolCalls) == 0 {
221+
return fmt.Errorf("current model doesn't support function call")
219222
}
223+
224+
args := openai.GetSummaryPrefixArgs(msg.ToolCalls[len(msg.ToolCalls)-1].Function.Arguments)
225+
summaryPrix = args.Prefix
226+
220227
color.Magenta("PromptTokens: " + strconv.Itoa(resp.Usage.PromptTokens) +
221228
", CompletionTokens: " + strconv.Itoa(resp.Usage.CompletionTokens) +
222229
", TotalTokens: " + strconv.Itoa(resp.Usage.TotalTokens),
223230
)
224231
} else {
232+
color.Cyan(message)
225233
resp, err := client.Completion(cmd.Context(), out)
226234
if err != nil {
227235
return err

openai/openai.go

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,13 @@ type Response struct {
9393
func (c *Client) CreateFunctionCall(
9494
ctx context.Context,
9595
content string,
96-
funcs ...openai.FunctionDefinition,
96+
f openai.FunctionDefinition,
9797
) (resp openai.ChatCompletionResponse, err error) {
98+
t := openai.Tool{
99+
Type: openai.ToolTypeFunction,
100+
Function: &f,
101+
}
102+
98103
req := openai.ChatCompletionRequest{
99104
Model: c.model,
100105
MaxTokens: c.maxTokens,
@@ -103,18 +108,14 @@ func (c *Client) CreateFunctionCall(
103108
FrequencyPenalty: c.frequencyPenalty,
104109
PresencePenalty: c.presencePenalty,
105110
Messages: []openai.ChatCompletionMessage{
106-
{
107-
Role: openai.ChatMessageRoleSystem,
108-
Content: "You are a helpful assistant.",
109-
},
110111
{
111112
Role: openai.ChatMessageRoleUser,
112113
Content: content,
113114
},
114115
},
115-
Functions: funcs,
116-
FunctionCall: "auto",
116+
Tools: []openai.Tool{t},
117117
}
118+
118119
return c.client.CreateChatCompletion(ctx, req)
119120
}
120121

@@ -326,10 +327,7 @@ func (c *Client) allowFuncCall(cfg *config) bool {
326327
openai.GPT3Dot5Turbo0125,
327328
openai.GPT3Dot5Turbo0613,
328329
openai.GPT3Dot5Turbo1106,
329-
groq.LLaMA38b.String(),
330-
groq.LLaMA370b.String(),
331-
groq.Mixtral8x7b.String(),
332-
groq.Gemma7b.String():
330+
groq.LLaMA38b.String():
333331
return true
334332
default:
335333
return false

0 commit comments

Comments
 (0)