Skip to content

Commit d4be369

Browse files
committed
Include function call history in restored threads.
This should discourage the model from lying about doing things the second time they are requested in a thread, by demonstrating to it how it handled the request the last time. Signed-off-by: Katharine Berry <[email protected]>
1 parent 0f27ba2 commit d4be369

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

service/assistant/functions/functions.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type Registration struct {
4646
Cb CallbackFunction
4747
// A function that summarises the provided input
4848
Thought ThoughtFunction
49+
// Whether to redact the function output in chat history.
50+
RedactOutputInChatHistory bool
4951
// An instance of the object used to hold the function's parameters. This is what will be passed to
5052
// either Fn or Cb, and it will also be processed to pass to the model - including the comments.
5153
InputType interface{}
@@ -231,3 +233,12 @@ func GetFunctionDefinitionsForCapabilities(capabilities []string) []*genai.Funct
231233
}
232234
return definitions
233235
}
236+
func GetFunctionRegistration(fn string) *Registration {
237+
if realFunction, ok := functionAliases[fn]; ok {
238+
fn = realFunction
239+
}
240+
if reg, ok := functionMap[fn]; ok {
241+
return &reg
242+
}
243+
return nil
244+
}

service/assistant/functions/wikipedia.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,10 @@ func init() {
7474
Required: []string{"wiki", "article_name"},
7575
},
7676
},
77-
Fn: queryWiki,
78-
Thought: queryWikiThought,
79-
InputType: WikiRequest{},
77+
Fn: queryWiki,
78+
Thought: queryWikiThought,
79+
RedactOutputInChatHistory: true,
80+
InputType: WikiRequest{},
8081
})
8182
}
8283

service/assistant/session.go

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -380,20 +380,38 @@ func (ps *PromptSession) Run(ctx context.Context) {
380380
}
381381

382382
type SerializedMessage struct {
383-
Role string `json:"role"`
384-
Content string `json:"content"`
383+
Role string `json:"role"`
384+
Content string `json:"content"`
385+
FunctionCall *genai.FunctionCall `json:"functionCall,omitempty"`
386+
FunctionResponse *genai.FunctionResponse `json:"functionResponse,omitempty"`
385387
}
386388

387389
func (ps *PromptSession) storeThread(ctx context.Context, messages []*genai.Content) error {
388390
ctx, span := beeline.StartSpan(ctx, "store_thread")
389391
defer span.Send()
390392
var toStore []SerializedMessage
391393
for _, m := range messages {
392-
if len(m.Parts) != 0 && (m.Role == "user" || m.Role == "model") && len(strings.TrimSpace(m.Parts[0].Text)) > 0 {
393-
toStore = append(toStore, SerializedMessage{
394-
Content: m.Parts[0].Text,
395-
Role: m.Role,
396-
})
394+
if len(m.Parts) != 0 {
395+
if m.Role == "user" || m.Role == "model" {
396+
sm := SerializedMessage{
397+
Role: m.Role,
398+
Content: m.Parts[0].Text,
399+
FunctionCall: m.Parts[0].FunctionCall,
400+
}
401+
if sm.FunctionCall != nil || len(strings.TrimSpace(m.Parts[0].Text)) > 0 {
402+
toStore = append(toStore, sm)
403+
}
404+
} else if m.Role == "function" && m.Parts[0].FunctionResponse != nil {
405+
fr := *m.Parts[0].FunctionResponse
406+
fnInfo := functions.GetFunctionRegistration(fr.Name)
407+
if fnInfo != nil && fnInfo.RedactOutputInChatHistory {
408+
fr.Response = map[string]any{"redacted": "redacted to reduce context size, call again if necessary"}
409+
}
410+
toStore = append(toStore, SerializedMessage{
411+
Role: m.Role,
412+
FunctionResponse: &fr,
413+
})
414+
}
397415
}
398416
}
399417
j, err := json.Marshal(toStore)
@@ -421,7 +439,7 @@ func (ps *PromptSession) restoreThread(ctx context.Context, oldThreadId string)
421439
var result []*genai.Content
422440
for _, m := range messages {
423441
result = append(result, &genai.Content{
424-
Parts: []*genai.Part{{Text: m.Content}},
442+
Parts: []*genai.Part{{Text: m.Content, FunctionCall: m.FunctionCall, FunctionResponse: m.FunctionResponse}},
425443
Role: m.Role,
426444
})
427445
}

0 commit comments

Comments
 (0)