Skip to content

Commit e874b9e

Browse files
committed
feat: Add ReasoningBank for reusable reasoning strategies
Implements ReasoningBank feature based on arXiv:2509.25140 paper. Core components: - ReasoningStrategy: distilled, reusable reasoning approach - ReasoningTrace: captures raw task execution data - BaseReasoningBankService: interface for storage/retrieval - InMemoryReasoningBankService: in-memory implementation with keyword matching - LoadReasoningStrategyTool: tool for agents to retrieve relevant strategies - SearchReasoningResponse: response model for strategy search Integration: - Added reasoningBankService to InvocationContext - Added searchReasoningStrategies() to ToolContext Includes 20 unit tests covering data models and service functionality.
1 parent b48b194 commit e874b9e

12 files changed

+1169
-0
lines changed

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.adk.memory.BaseMemoryService;
2323
import com.google.adk.models.LlmCallsLimitExceededException;
2424
import com.google.adk.plugins.PluginManager;
25+
import com.google.adk.reasoning.BaseReasoningBankService;
2526
import com.google.adk.sessions.BaseSessionService;
2627
import com.google.adk.sessions.Session;
2728
import com.google.common.collect.ImmutableSet;
@@ -42,6 +43,7 @@ public class InvocationContext {
4243
private final BaseSessionService sessionService;
4344
private final BaseArtifactService artifactService;
4445
private final BaseMemoryService memoryService;
46+
private final BaseReasoningBankService reasoningBankService;
4547
private final PluginManager pluginManager;
4648
private final Optional<LiveRequestQueue> liveRequestQueue;
4749
private final Map<String, ActiveStreamingTool> activeStreamingTools = new ConcurrentHashMap<>();
@@ -60,6 +62,7 @@ private InvocationContext(Builder builder) {
6062
this.sessionService = builder.sessionService;
6163
this.artifactService = builder.artifactService;
6264
this.memoryService = builder.memoryService;
65+
this.reasoningBankService = builder.reasoningBankService;
6366
this.pluginManager = builder.pluginManager;
6467
this.liveRequestQueue = builder.liveRequestQueue;
6568
this.branch = builder.branch;
@@ -204,6 +207,7 @@ public static InvocationContext copyOf(InvocationContext other) {
204207
.sessionService(other.sessionService)
205208
.artifactService(other.artifactService)
206209
.memoryService(other.memoryService)
210+
.reasoningBankService(other.reasoningBankService)
207211
.pluginManager(other.pluginManager)
208212
.liveRequestQueue(other.liveRequestQueue)
209213
.branch(other.branch)
@@ -234,6 +238,11 @@ public BaseMemoryService memoryService() {
234238
return memoryService;
235239
}
236240

241+
/** Returns the reasoning bank service for accessing reasoning strategies. */
242+
public BaseReasoningBankService reasoningBankService() {
243+
return reasoningBankService;
244+
}
245+
237246
/** Returns the plugin manager for accessing tools and plugins. */
238247
public PluginManager pluginManager() {
239248
return pluginManager;
@@ -376,6 +385,7 @@ public static class Builder {
376385
private BaseSessionService sessionService;
377386
private BaseArtifactService artifactService;
378387
private BaseMemoryService memoryService;
388+
private BaseReasoningBankService reasoningBankService;
379389
private PluginManager pluginManager = new PluginManager();
380390
private Optional<LiveRequestQueue> liveRequestQueue = Optional.empty();
381391
private Optional<String> branch = Optional.empty();
@@ -423,6 +433,18 @@ public Builder memoryService(BaseMemoryService memoryService) {
423433
return this;
424434
}
425435

436+
/**
437+
* Sets the reasoning bank service for accessing reasoning strategies.
438+
*
439+
* @param reasoningBankService the reasoning bank service to use.
440+
* @return this builder instance for chaining.
441+
*/
442+
@CanIgnoreReturnValue
443+
public Builder reasoningBankService(BaseReasoningBankService reasoningBankService) {
444+
this.reasoningBankService = reasoningBankService;
445+
return this;
446+
}
447+
426448
/**
427449
* Sets the plugin manager for accessing tools and plugins.
428450
*
@@ -608,6 +630,7 @@ public boolean equals(Object o) {
608630
&& Objects.equals(sessionService, that.sessionService)
609631
&& Objects.equals(artifactService, that.artifactService)
610632
&& Objects.equals(memoryService, that.memoryService)
633+
&& Objects.equals(reasoningBankService, that.reasoningBankService)
611634
&& Objects.equals(pluginManager, that.pluginManager)
612635
&& Objects.equals(liveRequestQueue, that.liveRequestQueue)
613636
&& Objects.equals(activeStreamingTools, that.activeStreamingTools)
@@ -626,6 +649,7 @@ public int hashCode() {
626649
sessionService,
627650
artifactService,
628651
memoryService,
652+
reasoningBankService,
629653
pluginManager,
630654
liveRequestQueue,
631655
activeStreamingTools,
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.reasoning;
18+
19+
import io.reactivex.rxjava3.core.Completable;
20+
import io.reactivex.rxjava3.core.Single;
21+
22+
/**
23+
* Base contract for reasoning bank services.
24+
*
25+
* <p>The service provides functionalities to store and retrieve reasoning strategies that can be
26+
* used to augment LLM prompts with relevant problem-solving approaches.
27+
*
28+
* <p>Based on the ReasoningBank paper (arXiv:2509.25140).
29+
*/
30+
public interface BaseReasoningBankService {
31+
32+
/**
33+
* Stores a reasoning strategy in the bank.
34+
*
35+
* @param appName The name of the application.
36+
* @param strategy The strategy to store.
37+
* @return A Completable that completes when the strategy is stored.
38+
*/
39+
Completable storeStrategy(String appName, ReasoningStrategy strategy);
40+
41+
/**
42+
* Stores a reasoning trace for later distillation into strategies.
43+
*
44+
* @param appName The name of the application.
45+
* @param trace The trace to store.
46+
* @return A Completable that completes when the trace is stored.
47+
*/
48+
Completable storeTrace(String appName, ReasoningTrace trace);
49+
50+
/**
51+
* Searches for reasoning strategies that match the given query.
52+
*
53+
* @param appName The name of the application.
54+
* @param query The query to search for (typically a task description).
55+
* @return A {@link SearchReasoningResponse} containing matching strategies.
56+
*/
57+
Single<SearchReasoningResponse> searchStrategies(String appName, String query);
58+
59+
/**
60+
* Searches for reasoning strategies that match the given query with a limit.
61+
*
62+
* @param appName The name of the application.
63+
* @param query The query to search for.
64+
* @param maxResults Maximum number of strategies to return.
65+
* @return A {@link SearchReasoningResponse} containing matching strategies.
66+
*/
67+
Single<SearchReasoningResponse> searchStrategies(String appName, String query, int maxResults);
68+
}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.reasoning;
18+
19+
import com.google.common.collect.ImmutableList;
20+
import com.google.common.collect.ImmutableSet;
21+
import io.reactivex.rxjava3.core.Completable;
22+
import io.reactivex.rxjava3.core.Single;
23+
import java.util.ArrayList;
24+
import java.util.Collections;
25+
import java.util.HashSet;
26+
import java.util.List;
27+
import java.util.Locale;
28+
import java.util.Map;
29+
import java.util.Set;
30+
import java.util.concurrent.ConcurrentHashMap;
31+
import java.util.regex.Matcher;
32+
import java.util.regex.Pattern;
33+
34+
/**
35+
* An in-memory reasoning bank service for prototyping purposes only.
36+
*
37+
* <p>Uses keyword matching instead of semantic search. For production use, consider implementing a
38+
* service backed by vector embeddings for semantic similarity matching.
39+
*/
40+
public final class InMemoryReasoningBankService implements BaseReasoningBankService {
41+
42+
private static final int DEFAULT_MAX_RESULTS = 5;
43+
44+
// Pattern to extract words for keyword matching.
45+
private static final Pattern WORD_PATTERN = Pattern.compile("[A-Za-z]+");
46+
47+
/** Keys are app names, values are lists of strategies. */
48+
private final Map<String, List<ReasoningStrategy>> strategies;
49+
50+
/** Keys are app names, values are lists of traces. */
51+
private final Map<String, List<ReasoningTrace>> traces;
52+
53+
public InMemoryReasoningBankService() {
54+
this.strategies = new ConcurrentHashMap<>();
55+
this.traces = new ConcurrentHashMap<>();
56+
}
57+
58+
@Override
59+
public Completable storeStrategy(String appName, ReasoningStrategy strategy) {
60+
return Completable.fromAction(
61+
() -> {
62+
List<ReasoningStrategy> appStrategies =
63+
strategies.computeIfAbsent(
64+
appName, k -> Collections.synchronizedList(new ArrayList<>()));
65+
appStrategies.add(strategy);
66+
});
67+
}
68+
69+
@Override
70+
public Completable storeTrace(String appName, ReasoningTrace trace) {
71+
return Completable.fromAction(
72+
() -> {
73+
List<ReasoningTrace> appTraces =
74+
traces.computeIfAbsent(appName, k -> Collections.synchronizedList(new ArrayList<>()));
75+
appTraces.add(trace);
76+
});
77+
}
78+
79+
@Override
80+
public Single<SearchReasoningResponse> searchStrategies(String appName, String query) {
81+
return searchStrategies(appName, query, DEFAULT_MAX_RESULTS);
82+
}
83+
84+
@Override
85+
public Single<SearchReasoningResponse> searchStrategies(
86+
String appName, String query, int maxResults) {
87+
return Single.fromCallable(
88+
() -> {
89+
if (!strategies.containsKey(appName)) {
90+
return SearchReasoningResponse.builder().build();
91+
}
92+
93+
List<ReasoningStrategy> appStrategies = strategies.get(appName);
94+
ImmutableSet<String> queryWords = extractWords(query);
95+
96+
if (queryWords.isEmpty()) {
97+
return SearchReasoningResponse.builder().build();
98+
}
99+
100+
List<ScoredStrategy> scoredStrategies = new ArrayList<>();
101+
102+
for (ReasoningStrategy strategy : appStrategies) {
103+
int score = calculateMatchScore(strategy, queryWords);
104+
if (score > 0) {
105+
scoredStrategies.add(new ScoredStrategy(strategy, score));
106+
}
107+
}
108+
109+
// Sort by score descending
110+
scoredStrategies.sort((a, b) -> Integer.compare(b.score, a.score));
111+
112+
// Take top results
113+
List<ReasoningStrategy> matchingStrategies = new ArrayList<>();
114+
for (int i = 0; i < Math.min(maxResults, scoredStrategies.size()); i++) {
115+
matchingStrategies.add(scoredStrategies.get(i).strategy);
116+
}
117+
118+
return SearchReasoningResponse.builder()
119+
.setStrategies(ImmutableList.copyOf(matchingStrategies))
120+
.build();
121+
});
122+
}
123+
124+
private int calculateMatchScore(ReasoningStrategy strategy, Set<String> queryWords) {
125+
int score = 0;
126+
127+
// Check problem pattern
128+
Set<String> patternWords = extractWords(strategy.problemPattern());
129+
score += countOverlap(queryWords, patternWords) * 3; // Weight pattern matches higher
130+
131+
// Check name
132+
Set<String> nameWords = extractWords(strategy.name());
133+
score += countOverlap(queryWords, nameWords) * 2;
134+
135+
// Check tags
136+
for (String tag : strategy.tags()) {
137+
Set<String> tagWords = extractWords(tag);
138+
score += countOverlap(queryWords, tagWords);
139+
}
140+
141+
// Check steps (lower weight)
142+
for (String step : strategy.steps()) {
143+
Set<String> stepWords = extractWords(step);
144+
if (!Collections.disjoint(queryWords, stepWords)) {
145+
score += 1;
146+
}
147+
}
148+
149+
return score;
150+
}
151+
152+
private int countOverlap(Set<String> set1, Set<String> set2) {
153+
int count = 0;
154+
for (String word : set1) {
155+
if (set2.contains(word)) {
156+
count++;
157+
}
158+
}
159+
return count;
160+
}
161+
162+
private ImmutableSet<String> extractWords(String text) {
163+
if (text == null || text.isEmpty()) {
164+
return ImmutableSet.of();
165+
}
166+
167+
Set<String> words = new HashSet<>();
168+
Matcher matcher = WORD_PATTERN.matcher(text);
169+
while (matcher.find()) {
170+
words.add(matcher.group().toLowerCase(Locale.ROOT));
171+
}
172+
return ImmutableSet.copyOf(words);
173+
}
174+
175+
/** Helper class for scoring strategies during search. */
176+
private static class ScoredStrategy {
177+
final ReasoningStrategy strategy;
178+
final int score;
179+
180+
ScoredStrategy(ReasoningStrategy strategy, int score) {
181+
this.strategy = strategy;
182+
this.score = score;
183+
}
184+
}
185+
}

0 commit comments

Comments
 (0)