1919import java .util .Map ;
2020import java .util .Optional ;
2121import java .util .Set ;
22- import java .util .concurrent .atomic .AtomicReference ;
2322import java .util .function .Predicate ;
2423import java .util .stream .Collectors ;
2524import java .util .stream .Stream ;
2625
2726import org .slf4j .Logger ;
2827import org .slf4j .LoggerFactory ;
2928import org .springframework .ai .chat .client .ChatClient ;
29+ import org .springframework .ai .chat .client .ChatClient .Builder ;
3030import org .springframework .ai .chat .messages .Message ;
3131import org .springframework .ai .chat .messages .UserMessage ;
3232import org .springframework .ai .chat .model .ChatResponse ;
3737import org .springframework .jdbc .core .JdbcTemplate ;
3838import org .springframework .jdbc .support .rowset .SqlRowSet ;
3939import org .springframework .stereotype .Service ;
40- import org .springframework .ai .chat .client .ChatClient .Builder ;
4140
4241import ch .xxx .aidoclibchat .domain .client .ImportClient ;
4342import ch .xxx .aidoclibchat .domain .common .MetaData ;
@@ -99,6 +98,10 @@ Pay attention to use date('now') function to get the current date, if the questi
9998 @ Value ("${spring.profiles.active:}" )
10099 private String activeProfile ;
101100
101+ record MyTableData (String joinColumn , String joinTable , String columnValue , List <TableNameSchema > tableRecords ,
102+ TableColumnNames tableColumnNames ) {
103+ }
104+
102105 public TableService (ImportClient importClient , ImportService importService , Builder builder ,
103106 JdbcTemplate jdbcTemplate , TableMetadataRepository tableMetadataRepository ,
104107 DocumentVsRepository documentVsRepository ) {
@@ -149,37 +152,48 @@ private Prompt createPrompt(SearchDto searchDto, EmbeddingContainer documentCont
149152 List <TableNameSchema > tableRecords = this .tableMetadataRepository
150153 .findByTableNameIn (tableColumnNames .tableNames ()).stream ()
151154 .map (tableMetaData -> new TableNameSchema (tableMetaData .getTableName (), tableMetaData .getTableDdl ()))
152- .collect (Collectors .toList ());
153- final AtomicReference <String > joinColumn = new AtomicReference <String >("" );
154- final AtomicReference <String > joinTable = new AtomicReference <String >("" );
155- final AtomicReference <String > columnValue = new AtomicReference <String >("" );
156- sortedRowDocs .stream ().filter (myDoc -> minRowDistance <= MAX_ROW_DISTANCE )
155+ .collect (Collectors .toList ());
156+ var result = sortedRowDocs .stream ().filter (myDoc -> minRowDistance <= MAX_ROW_DISTANCE )
157157 .filter (myRowDoc -> tableRecords .stream ()
158158 .filter (myRecord -> myRecord .name ().equals (myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )))
159159 .findFirst ().isEmpty ())
160- .findFirst ().ifPresent (myRowDoc -> {
161- joinTable .set (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )));
162- joinColumn .set (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
163- tableColumnNames .columnNames ().add (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
164- columnValue .set (myRowDoc .getText ());
165- this .tableMetadataRepository
166- .findByTableNameIn (List .of (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME ))))
167- .stream ()
168- .map (myTableMetadata -> new TableNameSchema (myTableMetadata .getTableName (),
169- myTableMetadata .getTableDdl ()))
170- .findFirst ().ifPresent (myRecord -> tableRecords .add (myRecord ));
171- });
172- var messages = this .createMessages (searchDto , minRowDistance , tableColumnNames , tableRecords , joinColumn ,
173- joinTable , columnValue );
160+ .findFirst ().map (myRowDoc -> createTableData (tableColumnNames , tableRecords , myRowDoc ))
161+ .orElseThrow ();
162+ var messages = this .createMessages (searchDto , minRowDistance , result .tableColumnNames (), result .tableRecords (), result .joinColumn (),
163+ result .joinTable (), result .columnValue ());
174164 Prompt prompt = new Prompt (messages );
175165// LOGGER.info("Prompt: {}", prompt.getContents());
176166 return prompt ;
177167 }
178168
169+ private MyTableData createTableData (TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
170+ Document myRowDoc ) {
171+ tableColumnNames .columnNames ().add (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )));
172+ return findTable (myRowDoc ).map (myRecord -> {
173+ tableRecords .add (myRecord );
174+ return createMyTableResult (tableColumnNames , tableRecords , myRowDoc );
175+ }).orElse (createMyTableResult (tableColumnNames , tableRecords , myRowDoc ));
176+ }
177+
178+ private MyTableData createMyTableResult (TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
179+ Document myRowDoc ) {
180+ return new MyTableData (((String ) myRowDoc .getMetadata ().get (MetaData .DATANAME )),
181+ ((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )), myRowDoc .getText (), tableRecords ,
182+ tableColumnNames );
183+ }
184+
185+ private Optional <TableNameSchema > findTable (Document myRowDoc ) {
186+ return this .tableMetadataRepository
187+ .findByTableNameIn (List .of (((String ) myRowDoc .getMetadata ().get (MetaData .TABLE_NAME )))).stream ()
188+ .map (myTableMetadata -> new TableNameSchema (myTableMetadata .getTableName (),
189+ myTableMetadata .getTableDdl ()))
190+ .findFirst ();
191+ }
192+
179193 private List <Message > createMessages (SearchDto searchDto , final Float minRowDistance ,
180194 TableColumnNames tableColumnNames , List <TableNameSchema > tableRecords ,
181- final AtomicReference < String > joinColumn , final AtomicReference < String > joinTable ,
182- final AtomicReference < String > columnValue ) {
195+ final String joinColumn , final String joinTable ,
196+ final String columnValue ) {
183197 SystemPromptTemplate systemPromptTemplate = this .activeProfile .contains ("ollama" )
184198 ? new SystemPromptTemplate (minRowDistance > MAX_ROW_DISTANCE ? String .format (this .ollamaPrompt , "" )
185199 : String .format (this .ollamaPrompt , columnMatch ))
@@ -188,8 +202,8 @@ private List<Message> createMessages(SearchDto searchDto, final Float minRowDist
188202 Message systemMessage = systemPromptTemplate .createMessage (
189203 Map .of ("columns" , tableColumnNames .columnNames ().stream ().collect (Collectors .joining ("," )), "schemas" ,
190204 tableRecords .stream ().map (myRecord -> myRecord .schema ()).collect (Collectors .joining (";" )),
191- "prompt" , searchDto .getSearchString (), "joinColumn" , joinColumn . get () , "joinTable" ,
192- joinTable . get () , "columnValue" , columnValue . get () ));
205+ "prompt" , searchDto .getSearchString (), "joinColumn" , joinColumn , "joinTable" ,
206+ joinTable , "columnValue" , columnValue ));
193207 UserMessage userMessage = this .activeProfile .contains ("ollama" ) ? new UserMessage (systemMessage .getText ())
194208 : new UserMessage (searchDto .getSearchString ());
195209 return List .of (systemMessage , userMessage );
0 commit comments