Skip to content

Commit bc50d28

Browse files
ThomasVitaletzolov
authored andcommitted
Modular RAG: Query Augmentor
* Add new QueryAugmentor API, a component for augmenting a user query with contextual data. * Implement ContextualQueryAugmentor that combines the content of each document and add it to the original user prompt, with support for the scenario where the context is empty. * Extend RetrievalAugmentationAdvisor to use the new augmentation building block. * Introduce utility to assist in validating arguments for prompt-related operations. Relates to gh-spring-projects#1603 Signed-off-by: Thomas Vitale <[email protected]>
1 parent cca4304 commit bc50d28

File tree

8 files changed

+477
-44
lines changed

8 files changed

+477
-44
lines changed

spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.List;
2121
import java.util.Map;
2222
import java.util.function.Predicate;
23-
import java.util.stream.Collectors;
2423

2524
import reactor.core.publisher.Flux;
2625
import reactor.core.publisher.Mono;
@@ -32,12 +31,12 @@
3231
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
3332
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor;
3433
import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain;
35-
import org.springframework.ai.chat.messages.UserMessage;
3634
import org.springframework.ai.chat.model.ChatResponse;
3735
import org.springframework.ai.chat.prompt.PromptTemplate;
3836
import org.springframework.ai.document.Document;
39-
import org.springframework.ai.model.Content;
4037
import org.springframework.ai.rag.Query;
38+
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
39+
import org.springframework.ai.rag.augmentation.QueryAugmentor;
4140
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
4241
import org.springframework.lang.Nullable;
4342
import org.springframework.util.Assert;
@@ -60,33 +59,19 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr
6059

6160
public static final String DOCUMENT_CONTEXT = "rag_document_context";
6261

63-
public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
64-
{query}
65-
66-
Context information is below. Use this information to answer the user query.
67-
68-
---------------------
69-
{context}
70-
---------------------
71-
72-
Given the context and provided history information and not prior knowledge,
73-
reply to the user query. If the answer is not in the context, inform
74-
the user that you can't answer the query.
75-
""");
76-
7762
private final DocumentRetriever documentRetriever;
7863

79-
private final PromptTemplate promptTemplate;
64+
private final QueryAugmentor queryAugmentor;
8065

8166
private final boolean protectFromBlocking;
8267

8368
private final int order;
8469

85-
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable PromptTemplate promptTemplate,
70+
public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor,
8671
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
8772
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
8873
this.documentRetriever = documentRetriever;
89-
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
74+
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
9075
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
9176
this.order = order != null ? order : 0;
9277
}
@@ -140,21 +125,10 @@ private AdvisedRequest before(AdvisedRequest request) {
140125
List<Document> documents = this.documentRetriever.retrieve(query);
141126
context.put(DOCUMENT_CONTEXT, documents);
142127

143-
// 2. Combine retrieved documents.
144-
String documentContext = documents.stream()
145-
.map(Content::getContent)
146-
.collect(Collectors.joining(System.lineSeparator()));
147-
148-
// 3. Define augmentation prompt parameters.
149-
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);
150-
151-
// 4. Augment user prompt with the context data.
152-
UserMessage augmentedUserMessage = (UserMessage) this.promptTemplate.createMessage(promptParameters);
128+
// 2. Augment user query with the contextual data.
129+
Query augmentedQuery = this.queryAugmentor.augment(query, documents);
153130

154-
return AdvisedRequest.from(request)
155-
.withUserText(augmentedUserMessage.getContent())
156-
.withAdviseContext(context)
157-
.build();
131+
return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
158132
}
159133

160134
private AdvisedResponse after(AdvisedResponse advisedResponse) {
@@ -185,7 +159,7 @@ public static final class Builder {
185159

186160
private DocumentRetriever documentRetriever;
187161

188-
private PromptTemplate promptTemplate;
162+
private QueryAugmentor queryAugmentor;
189163

190164
private Boolean protectFromBlocking;
191165

@@ -199,8 +173,8 @@ public Builder documentRetriever(DocumentRetriever documentRetriever) {
199173
return this;
200174
}
201175

202-
public Builder promptTemplate(PromptTemplate promptTemplate) {
203-
this.promptTemplate = promptTemplate;
176+
public Builder queryAugmentor(QueryAugmentor queryAugmentor) {
177+
this.queryAugmentor = queryAugmentor;
204178
return this;
205179
}
206180

@@ -215,7 +189,7 @@ public Builder order(Integer order) {
215189
}
216190

217191
public RetrievalAugmentationAdvisor build() {
218-
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.promptTemplate,
192+
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor,
219193
this.protectFromBlocking, this.order);
220194
}
221195

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.rag.augmentation;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.stream.Collectors;
22+
23+
import org.springframework.ai.chat.prompt.PromptTemplate;
24+
import org.springframework.ai.document.Document;
25+
import org.springframework.ai.model.Content;
26+
import org.springframework.ai.rag.Query;
27+
import org.springframework.ai.util.PromptAssert;
28+
import org.springframework.lang.Nullable;
29+
import org.springframework.util.Assert;
30+
31+
/**
32+
* Augments the user query with contextual data.
33+
*
34+
* <p>
35+
* Example usage: <pre>{@code
36+
* QueryAugmentor augmentor = ContextualQueryAugmentor.builder()
37+
* .promptTemplate(promptTemplate)
38+
* .emptyContextPromptTemplate(emptyContextPromptTemplate)
39+
* .allowEmptyContext(allowEmptyContext)
40+
* .build();
41+
* Query augmentedQuery = augmentor.augment(query, documents);
42+
* }</pre>
43+
*
44+
* @author Thomas Vitale
45+
* @since 1.0.0
46+
*/
47+
public class ContextualQueryAugmentor implements QueryAugmentor {
48+
49+
private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
50+
Context information is below.
51+
52+
---------------------
53+
{context}
54+
---------------------
55+
56+
Given the context information and no prior knowledge, answer the query.
57+
58+
Follow these rules:
59+
60+
1. If the answer is not in the context, just say that you don't know.
61+
2. Avoid statements like "Based on the context..." or "The provided information...".
62+
63+
Query: {query}
64+
65+
Answer:
66+
""");
67+
68+
private static final PromptTemplate DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE = new PromptTemplate("""
69+
The user query is outside your knowledge base.
70+
Politely inform the user that you can't answer it.
71+
""");
72+
73+
private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = true;
74+
75+
private final PromptTemplate promptTemplate;
76+
77+
private final PromptTemplate emptyContextPromptTemplate;
78+
79+
private final boolean allowEmptyContext;
80+
81+
public ContextualQueryAugmentor(@Nullable PromptTemplate promptTemplate,
82+
@Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext) {
83+
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
84+
this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate
85+
: DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE;
86+
this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT;
87+
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context");
88+
}
89+
90+
@Override
91+
public Query augment(Query query, List<Document> documents) {
92+
Assert.notNull(query, "query cannot be null");
93+
Assert.notNull(documents, "documents cannot be null");
94+
95+
if (documents.isEmpty()) {
96+
return augmentQueryWhenEmptyContext(query);
97+
}
98+
99+
// 1. Join documents.
100+
String documentContext = documents.stream()
101+
.map(Content::getContent)
102+
.collect(Collectors.joining(System.lineSeparator()));
103+
104+
// 2. Define prompt parameters.
105+
Map<String, Object> promptParameters = Map.of("query", query.text(), "context", documentContext);
106+
107+
// 3. Augment user prompt with document context.
108+
return new Query(this.promptTemplate.render(promptParameters));
109+
}
110+
111+
private Query augmentQueryWhenEmptyContext(Query query) {
112+
if (this.allowEmptyContext) {
113+
return query;
114+
}
115+
return new Query(this.emptyContextPromptTemplate.render());
116+
}
117+
118+
public static Builder builder() {
119+
return new Builder();
120+
}
121+
122+
public static class Builder {
123+
124+
private PromptTemplate promptTemplate;
125+
126+
private PromptTemplate emptyContextPromptTemplate;
127+
128+
private Boolean allowEmptyContext;
129+
130+
public Builder promptTemplate(PromptTemplate promptTemplate) {
131+
this.promptTemplate = promptTemplate;
132+
return this;
133+
}
134+
135+
public Builder emptyContextPromptTemplate(PromptTemplate emptyContextPromptTemplate) {
136+
this.emptyContextPromptTemplate = emptyContextPromptTemplate;
137+
return this;
138+
}
139+
140+
public Builder allowEmptyContext(Boolean allowEmptyContext) {
141+
this.allowEmptyContext = allowEmptyContext;
142+
return this;
143+
}
144+
145+
public ContextualQueryAugmentor build() {
146+
return new ContextualQueryAugmentor(this.promptTemplate, this.emptyContextPromptTemplate,
147+
this.allowEmptyContext);
148+
}
149+
150+
}
151+
152+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.rag.augmentation;
18+
19+
import java.util.List;
20+
import java.util.function.BiFunction;
21+
22+
import org.springframework.ai.document.Document;
23+
import org.springframework.ai.rag.Query;
24+
25+
/**
26+
* Component for augmenting a query with contextual data based on a specific strategy.
27+
*
28+
* @author Thomas Vitale
29+
* @since 1.0.0
30+
*/
31+
public interface QueryAugmentor extends BiFunction<Query, List<Document>, Query> {
32+
33+
/**
34+
* Augments the user query with contextual data.
35+
* @param query The user query to augment
36+
* @param documents The contextual data to use for augmentation
37+
* @return The augmented query
38+
*/
39+
Query augment(Query query, List<Document> documents);
40+
41+
/**
42+
* Augments the user query with contextual data.
43+
* @param query The user query to augment
44+
* @param documents The contextual data to use for augmentation
45+
* @return The augmented query
46+
*/
47+
default Query apply(Query query, List<Document> documents) {
48+
return augment(query, documents);
49+
}
50+
51+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
/**
18+
* RAG Module: Augmentation.
19+
* <p>
20+
* This package provides the functional building blocks for augmenting a user query with
21+
* contextual data.
22+
*/
23+
24+
@NonNullApi
25+
@NonNullFields
26+
package org.springframework.ai.rag.augmentation;
27+
28+
import org.springframework.lang.NonNullApi;
29+
import org.springframework.lang.NonNullFields;

0 commit comments

Comments
 (0)