1- using ChatGptNet ;
2- using ChatGptNet . Extensions ;
3- using DatabaseGpt . Abstractions ;
1+ using DatabaseGpt . Abstractions ;
42using DatabaseGpt . Exceptions ;
53using DatabaseGpt . Models ;
64using DatabaseGpt . Settings ;
5+ using Microsoft . Extensions . AI ;
6+ using Microsoft . Extensions . Caching . Hybrid ;
77using Polly ;
88using Polly . Registry ;
9+ using ChatHistory = System . Collections . Generic . List < Microsoft . Extensions . AI . ChatMessage > ;
910
1011namespace DatabaseGpt ;
1112
12- internal class DatabaseGptClient ( IChatGptClient chatGptClient , ResiliencePipelineProvider < string > pipelineProvider , IServiceProvider serviceProvider , DatabaseGptSettings databaseGptSettings ) : IDatabaseGptClient
13+ internal class DatabaseGptClient ( IChatClient chatGptClient , HybridCache cache , ResiliencePipelineProvider < string > pipelineProvider , IServiceProvider serviceProvider , DatabaseGptSettings databaseGptSettings ) : IDatabaseGptClient
1314{
1415 private readonly IDatabaseGptProvider provider = databaseGptSettings . CreateProvider ( ) ;
1516 private readonly ResiliencePipeline pipeline = pipelineProvider . GetPipeline ( nameof ( DatabaseGptClient ) ) ;
@@ -48,8 +49,8 @@ private async Task<DatabaseGptQueryResult> ExecuteNaturalLanguageQueryInternalAs
4849
4950 private async Task < Guid > CreateSessionAsync ( Guid sessionId , CancellationToken cancellationToken )
5051 {
51- var conversationExists = await chatGptClient . ConversationExistsAsync ( sessionId , cancellationToken ) ;
52- if ( ! conversationExists )
52+ var history = await GetChatHistoryAsync ( sessionId , cancellationToken ) ;
53+ if ( history . Count == 0 )
5354 {
5455 var tables = await provider . GetTablesAsync ( databaseGptSettings . IncludedTables , databaseGptSettings . ExcludedTables , cancellationToken ) ;
5556
@@ -67,7 +68,8 @@ private async Task<Guid> CreateSessionAsync(Guid sessionId, CancellationToken ca
6768 """ ;
6869 }
6970
70- sessionId = await chatGptClient . SetupAsync ( sessionId , systemMessage , cancellationToken ) ;
71+ history . Add ( new ( ChatRole . System , systemMessage ) ) ;
72+ await UpdateCacheAsync ( sessionId , history , cancellationToken ) ;
7173 }
7274
7375 return sessionId ;
@@ -90,9 +92,15 @@ The selected tables should be returned in a comma separated list. Your response
9092 await options . OnStarting . Invoke ( serviceProvider ) ;
9193 }
9294
93- var response = await chatGptClient . AskAsync ( sessionId , request , cancellationToken : cancellationToken ) ;
95+ var chat = await GetChatHistoryAsync ( sessionId , cancellationToken ) ;
96+ chat . Add ( new ( ChatRole . User , question ) ) ;
9497
95- var candidateTables = response . GetContent ( ) ! . Trim ( '\' ' ) ;
98+ var response = await chatGptClient . GetResponseAsync ( chat , cancellationToken : cancellationToken ) ;
99+
100+ chat . Add ( new ( ChatRole . Assistant , response . Text ) ) ;
101+ await UpdateCacheAsync ( sessionId , chat , cancellationToken ) ;
102+
103+ var candidateTables = response . Text . Trim ( '\' ' ) ;
96104 if ( candidateTables == "NONE" )
97105 {
98106 throw new NoTableFoundException ( $ "No available information in the provided tables can be useful for the question '{ question } '.") ;
@@ -134,9 +142,15 @@ CREATE TABLE Table2 (Column3 VARCHAR(255), Column4 VARCHAR(255))
134142 request += $ "{ Environment . NewLine } { queryHints } ";
135143 }
136144
137- var response = await chatGptClient . AskAsync ( sessionId , request , cancellationToken : cancellationToken ) ;
145+ var chat = await GetChatHistoryAsync ( sessionId , cancellationToken ) ;
146+ chat . Add ( new ( ChatRole . User , question ) ) ;
147+
148+ var response = await chatGptClient . GetResponseAsync ( chat , cancellationToken : cancellationToken ) ;
149+
150+ chat . Add ( new ( ChatRole . Assistant , response . Text ) ) ;
151+ await UpdateCacheAsync ( sessionId , chat , cancellationToken ) ;
138152
139- var query = response . GetContent ( ) ! ;
153+ var query = response . Text ;
140154 if ( query == "NONE" )
141155 {
142156 throw new InvalidSqlException ( $ "The question '{ question } ' requires an INSERT, UPDATE or DELETE command, that isn't supported.") ;
@@ -155,6 +169,27 @@ CREATE TABLE Table2 (Column3 VARCHAR(255), Column4 VARCHAR(255))
155169 return query ;
156170 }
157171
172+ private async Task UpdateCacheAsync ( Guid conversationId , ChatHistory chat , CancellationToken cancellationToken )
173+ {
174+ if ( chat . Count > databaseGptSettings . MessageLimit )
175+ {
176+ chat . RemoveRange ( 0 , chat . Count - databaseGptSettings . MessageLimit ) ;
177+ }
178+
179+ await cache . SetAsync ( conversationId . ToString ( ) , chat , cancellationToken : cancellationToken ) ;
180+ }
181+
182+ private async Task < ChatHistory > GetChatHistoryAsync ( Guid conversationId , CancellationToken cancellationToken )
183+ {
184+ var historyCache = await cache . GetOrCreateAsync ( conversationId . ToString ( ) , ( cancellationToken ) =>
185+ {
186+ return ValueTask . FromResult < ChatHistory > ( [ ] ) ;
187+ } , cancellationToken : cancellationToken ) ;
188+
189+ var chat = new ChatHistory ( historyCache ) ;
190+ return chat ;
191+ }
192+
158193 protected virtual void Dispose ( bool disposing )
159194 {
160195 if ( ! disposedValue )
0 commit comments