
Commit 42d0f82

Add custom jinja template for text generation

1 parent e4edb5e

File tree: 4 files changed, +55 −32 lines

lib/langchain/chain_builder.dart
lib/langchain/jinja_prompt_template.dart
lib/providers/text_inference_provider.dart
openvino_bindings/src/llm/llm_inference.cc

lib/langchain/chain_builder.dart

Lines changed: 11 additions & 21 deletions
@@ -17,44 +17,34 @@ String combineDocuments(
     documents.map((final d) => d.pageContent).join(separator);
 
 
-RAGChain buildRAGChain(LLMInference llmInference, Embeddings embeddings, OpenVINOLLMOptions options, List<VectorStore> stores) {
+RAGChain buildRAGChain(LLMInference llmInference, Embeddings embeddings, OpenVINOLLMOptions options, List<VectorStore> stores, BaseChatMemory memory) {
   final retrievers = combineStores(stores);
 
+  final tokenizerConfig = jsonDecode(llmInference.getTokenizerConfig()) as Map<String, dynamic>;
+
   final retrievedDocs = Runnable.fromMap({
     'docs': Runnable.getItemFromMap('question') | retrievers,
     'question': Runnable.getItemFromMap('question'),
   });
 
-  if (stores.isEmpty) {
-    final model = OpenVINOLLM(llmInference, defaultOptions: options.copyWith(applyTemplate: true));
-    final answer = PromptTemplate.fromTemplate('{question}') | model;
-    return RAGChain(retrievedDocs, answer);
-  }
-
-  final tokenizerConfig = jsonDecode(llmInference.getTokenizerConfig()) as Map<String, dynamic>;
-
-  final hasChatTemplate = tokenizerConfig.containsKey("chat_template");
-
-  // if chat template, otherwise
-  final promptTemplate = hasChatTemplate
-      ? JinjaPromptTemplate.fromTemplateConfig(tokenizerConfig)
-      : ChatPromptTemplate.fromTemplate('''
-Answer the question based only on the following context without specifically naming that it's from that context:
-{context}
-
-Question: {question}
-''');
+  final promptTemplate = JinjaPromptTemplate.fromTemplateConfig(tokenizerConfig);
 
   final finalInputs = Runnable.fromMap({
     'context': Runnable.getItemFromMap<List<Document>>('docs') |
         Runnable.mapInput<List<Document>, String>(combineDocuments),
     'question': Runnable.getItemFromMap('question'),
+    'history': Runnable.getItemFromMap('question') | Runnable.mapInput((_) async {
+      final m = await memory.loadMemoryVariables();
+      return m['history'];
+    }),
   });
   final model = OpenVINOLLM(llmInference, defaultOptions: options.copyWith(applyTemplate: false));
 
   final answer = finalInputs | promptTemplate | model;
 
+  finalInputs.invoke({'docs': List<Document>.from([]), 'question': "What is the color of the sun?"}).then(print);
+
+
   return RAGChain(retrievedDocs, answer);
 }

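A minimal sketch of the new 'history' input, assuming only the Runnable and ConversationBufferMemory APIs that appear in this diff; the main() wrapper and sample strings are illustrative, not part of the commit:

    import 'package:langchain/langchain.dart';

    Future<void> main() async {
      final memory = ConversationBufferMemory(returnMessages: true);

      // One prior exchange, saved the same way text_inference_provider.dart does.
      await memory.saveContext(
        inputValues: {'input': 'What is OpenVINO?'},
        outputValues: {'output': 'An inference toolkit.'},
      );

      // The same runnable buildRAGChain wires into finalInputs: it ignores the
      // question value and returns the buffered chat messages instead.
      final history = Runnable.getItemFromMap('question') |
          Runnable.mapInput((_) async {
            final m = await memory.loadMemoryVariables();
            return m['history'];
          });

      // Prints the HumanChatMessage/AIChatMessage list that
      // JinjaPromptTemplate.formatPrompt() later maps to role/content entries.
      print(await history.invoke({'question': 'And what does GenAI add?'}));
    }
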
lib/langchain/jinja_prompt_template.dart

Lines changed: 28 additions & 5 deletions
@@ -5,6 +5,14 @@
 import 'package:jinja/jinja.dart';
 import 'package:langchain/langchain.dart';
 
+const textGenerationTemplate = """
+{% for message in messages %}{%- if message['role'] == 'system' %}{{message['content']}}
+
+Question:
+{%- endif %}{% endfor %}
+{% for message in messages %}{%- if message['role'] == 'user' %}{{message['content']}}{%- endif %}{% endfor %}
+""";
+
 final class JinjaPromptTemplate extends BaseChatPromptTemplate {
   final Template jinjaTemplate;
 
@@ -19,7 +27,9 @@ final class JinjaPromptTemplate extends BaseChatPromptTemplate {
   });
 
   factory JinjaPromptTemplate.fromTemplateConfig(Map<String, dynamic> chatTemplateConfig, [Set<String> inputVariables = const {}]) {
-    final chatTemplate = chatTemplateConfig["chat_template"];
+    final chatTemplate = chatTemplateConfig.containsKey("chat_template")
+        ? chatTemplateConfig["chat_template"]
+        : textGenerationTemplate;
     final env = Environment();
     final template = env.fromString(chatTemplate);
 
@@ -46,10 +56,23 @@ final class JinjaPromptTemplate extends BaseChatPromptTemplate {
 
   @override
   PromptValue formatPrompt(final InputValues values) {
-    final messages =[
-      {"role": "system", "content": "Answer the question based on some info:\n ${values['context']}"},
-      {"role": "user", "content": values['question']},
-    ];
+    List<Map<String, dynamic>> messages = [];
+    if (values.containsKey('history')) {
+      for (final message in values['history']) {
+        if (message is AIChatMessage) {
+          messages.add({"role": "assistant", "content": message.contentAsString});
+        }
+        if (message is HumanChatMessage) {
+          messages.add({"role": "user", "content": message.contentAsString});
+        }
+      }
+    }
+    if (values.containsKey('context') && values['context'] != "") {
+      messages.add({"role": "system", "content": "Answer the question based on some info:\n ${values['context']}"});
+    }
+    if (values.containsKey('question')) {
+      messages.add({"role": "user", "content": values['question']});
+    }
 
     return PromptValue.string(jinjaTemplate.render(
       {

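A minimal sketch of what the fallback textGenerationTemplate renders, assuming a scratch file placed next to lib/langchain/jinja_prompt_template.dart so it can import the constant above; the sample messages are illustrative:

    import 'package:jinja/jinja.dart';
    import 'jinja_prompt_template.dart'; // assumed sibling file, for textGenerationTemplate

    void main() {
      // Same Environment/fromString/render calls the factory and formatPrompt() use.
      final template = Environment().fromString(textGenerationTemplate);
      print(template.render({
        'messages': [
          {'role': 'system', 'content': 'Answer the question based on some info:\n The sun is yellow.'},
          {'role': 'user', 'content': 'What is the color of the sun?'},
        ],
      }));
    }

Rendered, this yields roughly the system content, a blank line, a "Question:" marker, and then the user turn: a plain text-generation prompt for models whose tokenizer config ships no chat_template.
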
lib/providers/text_inference_provider.dart

Lines changed: 12 additions & 3 deletions
@@ -62,6 +62,8 @@ class TextInferenceProvider extends ChangeNotifier {
   String? get device => _device;
   Metrics? get metrics => _messages.lastOrNull?.metrics;
 
+  final memory = ConversationBufferMemory(returnMessages: true);
+
   final List<UserFile> _userFiles = [];
 
   Future<void> addUserFiles(List<UserFile> files ) async {
@@ -199,10 +201,11 @@ class TextInferenceProvider extends ChangeNotifier {
       stores.add(ObjectBoxStore(embeddings: embeddingsModel!, group: knowledgeGroup!));
     }
 
-    final chain = buildRAGChain(_inference!, embeddingsModel!, OpenVINOLLMOptions(temperature: temperature, topP: topP), stores);
+    final chain = buildRAGChain(_inference!, embeddingsModel!, OpenVINOLLMOptions(temperature: temperature, topP: topP), stores, memory);
     final input = await chain.documentChain.invoke({"question": message}) as Map;
-    print(input);
-    final docs = List<String>.from(input["docs"].map((Document doc) => doc.metadata["source"]).toSet());
+    final docs = input.containsKey("docs")
+        ? List<String>.from(input["docs"].map((Document doc) => doc.metadata["source"]).toSet())
+        : null;
 
     String modelOutput = "";
     Metrics? metrics;
@@ -216,6 +219,11 @@ class TextInferenceProvider extends ChangeNotifier {
       onToken(token);
     }
 
+    memory.saveContext(
+      inputValues: {'input': message},
+      outputValues: {'output': modelOutput},
+    );
+
     if (_messages.isNotEmpty) {
       _messages.add(Message(Speaker.assistant, modelOutput, metrics, DateTime.now(), sources: docs));
     }
@@ -250,6 +258,7 @@ class TextInferenceProvider extends ChangeNotifier {
   void reset() {
     _inference?.forceStop();
     _inference?.clearHistory();
+    memory.clear();
     for (final file in _userFiles) {
       final ids = file.documents.map((p) => p.id).whereType<String>().toList();
       store?.delete(ids: ids);

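A minimal sketch of the memory round-trip the provider now performs (save after a generation, load on the next turn, clear on reset); the strings are illustrative:

    import 'package:langchain/langchain.dart';

    Future<void> main() async {
      final memory = ConversationBufferMemory(returnMessages: true);

      // After streaming finishes, message() stores the exchange.
      await memory.saveContext(
        inputValues: {'input': 'What is the color of the sun?'},
        outputValues: {'output': 'It appears yellow from Earth.'},
      );

      // buildRAGChain reads it back as chat messages on the next question.
      final vars = await memory.loadMemoryVariables();
      print(vars['history']); // [HumanChatMessage(...), AIChatMessage(...)]

      // reset() wipes the buffer together with the native pipeline history.
      await memory.clear();
    }
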
openvino_bindings/src/llm/llm_inference.cc

Lines changed: 4 additions & 3 deletions
@@ -30,9 +30,10 @@ ov::genai::DecodedResults LLMInference::prompt(std::string message, bool apply_t
   history.push_back({{"role", "user"}, {"content", message}});
   _stop = false;
 
-  auto prompt = (apply_template && has_chat_template()
-      ? pipe.get_tokenizer().apply_chat_template(history, true)
-      : message);
+  //auto prompt = (apply_template && has_chat_template()
+  //    ? pipe.get_tokenizer().apply_chat_template(history, true)
+  //    : message);
+  auto prompt = message;
 
   ov::genai::GenerationConfig config;
   config.max_new_tokens = 1000;
