Skip to content

Commit d4f651b

Browse files
committed
Store the last POI results in a thread for reuse.
Signed-off-by: Katharine Berry <ktbry@google.com>
1 parent c5162e9 commit d4f651b

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

service/assistant/functions/poi.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ import (
2929
"strings"
3030
)
3131

32-
type POIQuery struct {
33-
Location string
34-
Query string
35-
LanguageCode string
36-
Units string
37-
}
38-
3932
type POIResponse struct {
4033
Results []util.POI
4134
Warning string `json:"CriticalRequirement,omitempty"`
@@ -70,12 +63,12 @@ func init() {
7063
},
7164
Fn: searchPoi,
7265
Thought: searchPoiThought,
73-
InputType: POIQuery{},
66+
InputType: util.POIQuery{},
7467
})
7568
}
7669

7770
func searchPoiThought(args any) string {
78-
poiQuery := args.(*POIQuery)
71+
poiQuery := args.(*util.POIQuery)
7972
if poiQuery.Location != "" {
8073
location, _, _ := strings.Cut(poiQuery.Location, ",")
8174
return fmt.Sprintf("Looking for %s near %s...", poiQuery.Query, location)
@@ -86,7 +79,14 @@ func searchPoiThought(args any) string {
8679
func searchPoi(ctx context.Context, quotaTracker *quota.Tracker, args any) any {
8780
ctx, span := beeline.StartSpan(ctx, "search_poi")
8881
defer span.Send()
89-
poiQuery := args.(*POIQuery)
82+
threadContext := query.ThreadContextFromContext(ctx)
83+
poiQuery := args.(*util.POIQuery)
84+
if threadContext.ContextStorage.PoiQuery != nil && poiQuery.Equal(threadContext.ContextStorage.PoiQuery) {
85+
log.Printf("Reusing the POI results from before.")
86+
return &POIResponse{
87+
Results: threadContext.ContextStorage.POIs,
88+
}
89+
}
9090
span.AddField("query", poiQuery.Query)
9191
location := query.LocationFromContext(ctx)
9292
if poiQuery.Location != "" {
@@ -183,8 +183,8 @@ func searchPoi(ctx context.Context, quotaTracker *quota.Tracker, args any) any {
183183
}
184184
}
185185

186-
threadContext := query.ThreadContextFromContext(ctx)
187186
threadContext.ContextStorage.POIs = pois
187+
threadContext.ContextStorage.PoiQuery = poiQuery
188188

189189
var attributionList []string
190190
for provider := range attributions {

service/assistant/persistence/persistence.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ type SerializedMessage struct {
3333
}
3434

3535
type StoredContext struct {
36-
POIs []util.POI `json:"pois"`
36+
PoiQuery *util.POIQuery `json:"poiQuery"`
37+
POIs []util.POI `json:"pois"`
3738
}
3839

3940
type ThreadContext struct {

service/assistant/util/poi.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,20 @@ type POI struct {
1414
DistanceMiles float64 `json:"DistanceMiles,omitempty"`
1515
Coordinates Coords `json:"Coordinates,omitempty"`
1616
}
17+
18+
type POIQuery struct {
19+
Location string
20+
Query string
21+
LanguageCode string
22+
Units string
23+
}
24+
25+
func (p *POIQuery) Equal(other *POIQuery) bool {
26+
if other == nil {
27+
return false
28+
}
29+
return p.Location == other.Location &&
30+
p.Query == other.Query &&
31+
p.LanguageCode == other.LanguageCode &&
32+
p.Units == other.Units
33+
}

0 commit comments

Comments
 (0)