Skip to content

Commit a5e958e

Browse files
authored
Merge pull request #1420 from rumpl/title-gen
Fix races
2 parents fc56a07 + 2fcfb8d commit a5e958e

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

pkg/runtime/runtime.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,8 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c
703703
r.registerDefaultTools()
704704

705705
if sess.Title == "" {
706-
r.titleGen.Generate(ctx, sess, events)
706+
userMessage := sess.GetLastUserMessageContent()
707+
r.titleGen.Generate(ctx, sess, userMessage, events)
707708
}
708709

709710
iteration := 0

pkg/runtime/runtime_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"reflect"
9+
"sync"
910
"testing"
1011
"time"
1112

@@ -568,12 +569,15 @@ func TestToolCallVariations(t *testing.T) {
568569
// queueProvider returns a different stream on each CreateChatCompletionStream call.
569570
type queueProvider struct {
570571
id string
572+
mu sync.Mutex
571573
streams []chat.MessageStream
572574
}
573575

574576
func (p *queueProvider) ID() string { return p.id }
575577

576578
func (p *queueProvider) CreateChatCompletionStream(context.Context, []chat.Message, []tools.Tool) (chat.MessageStream, error) {
579+
p.mu.Lock()
580+
defer p.mu.Unlock()
577581
if len(p.streams) == 0 {
578582
return &mockStream{}, nil
579583
}

pkg/runtime/title_generator.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,22 @@ func newTitleGenerator(model provider.Provider) *titleGenerator {
2929
}
3030
}
3131

32-
func (t *titleGenerator) Generate(ctx context.Context, sess *session.Session, events chan<- Event) {
32+
func (t *titleGenerator) Generate(ctx context.Context, sess *session.Session, userMessage string, events chan<- Event) {
33+
if userMessage == "" {
34+
return
35+
}
3336
t.wg.Go(func() {
34-
t.generate(ctx, sess, events)
37+
t.generate(ctx, sess, userMessage, events)
3538
})
3639
}
3740

3841
func (t *titleGenerator) Wait() {
3942
t.wg.Wait()
4043
}
4144

42-
func (t *titleGenerator) generate(ctx context.Context, sess *session.Session, events chan<- Event) {
45+
func (t *titleGenerator) generate(ctx context.Context, sess *session.Session, firstUserMessage string, events chan<- Event) {
4346
slog.Debug("Generating title for session", "session_id", sess.ID)
4447

45-
firstUserMessage := sess.GetLastUserMessageContent()
46-
if firstUserMessage == "" {
47-
return
48-
}
49-
5048
userPrompt := fmt.Sprintf(titleUserPromptFormat, firstUserMessage)
5149

5250
titleModel := provider.CloneWithOptions(

0 commit comments

Comments
 (0)