Skip to content

Commit 566247c

Browse files
committed
feat: added src code
1 parent c34e056 commit 566247c

File tree

16 files changed

+975
-0
lines changed

16 files changed

+975
-0
lines changed

.idea/compiler.xml

Lines changed: 18 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/encodings.xml

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/jarRepositories.xml

Lines changed: 25 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2024, Oracle and/or its affiliates.
2+
// Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3+
4+
// ORACLE AND ITS AFFILIATES DO NOT PROVIDE ANY WARRANTY WHATSOEVER, EXPRESS OR IMPLIED,
5+
// FOR ANY SOFTWARE, MATERIAL OR CONTENT OF ANY KIND CONTAINED OR PRODUCED WITHIN THIS REPOSITORY,
6+
// AND IN PARTICULAR SPECIFICALLY DISCLAIM ANY AND ALL IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT,
7+
// MERCHANTABILITY, AND FITNESS FOR A PARTICULAR PURPOSE. FURTHERMORE, ORACLE AND ITS AFFILIATES
8+
// DO NOT REPRESENT THAT ANY CUSTOMARY SECURITY REVIEW HAS BEEN PERFORMED WITH RESPECT TO ANY SOFTWARE,
9+
// MATERIAL OR CONTENT CONTAINED OR PRODUCED WITHIN THIS REPOSITORY. IN ADDITION, AND WITHOUT LIMITING
10+
// THE FOREGOING, THIRD PARTIES MAY HAVE POSTED SOFTWARE, MATERIAL OR CONTENT TO THIS REPOSITORY WITHOUT
11+
// ANY REVIEW. USE AT YOUR OWN RISK.
12+
package com.example.demoai;
13+
14+
15+
import org.springframework.ai.embedding.EmbeddingClient;
16+
import org.springframework.ai.openai.OpenAiEmbeddingClient;
17+
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
18+
import org.springframework.ai.vectorstore.VectorStore;
19+
import org.springframework.beans.factory.annotation.Autowired;
20+
import org.springframework.boot.SpringApplication;
21+
import org.springframework.boot.autoconfigure.SpringBootApplication;
22+
import org.springframework.context.annotation.Bean;
23+
import org.springframework.jdbc.core.JdbcTemplate;
24+
//CHANGE
25+
//import org.springframework.ai.ollama.OllamaEmbeddingClient;
26+
27+
28+
@SpringBootApplication
29+
public class DemoaiApplication {
30+
31+
32+
public static void main(String[] args) {
33+
SpringApplication.run(DemoaiApplication.class, args);
34+
}
35+
36+
//CHANGE
37+
@Bean
38+
VectorStore vectorStore(EmbeddingClient ec, JdbcTemplate t) {
39+
//VectorStore vectorStore(OllamaEmbeddingClient ec, JdbcTemplate t) {
40+
return new OracleDBVectorStore(t, ec);
41+
}
42+
43+
@Bean
44+
TokenTextSplitter tokenTextSplitter() {
45+
return new TokenTextSplitter();
46+
}
47+
}
48+
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
// Copyright (c) 2024, Oracle and/or its affiliates.
2+
// Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3+
4+
// ORACLE AND ITS AFFILIATES DO NOT PROVIDE ANY WARRANTY WHATSOEVER, EXPRESS OR IMPLIED,
5+
// FOR ANY SOFTWARE, MATERIAL OR CONTENT OF ANY KIND CONTAINED OR PRODUCED WITHIN THIS REPOSITORY,
6+
// AND IN PARTICULAR SPECIFICALLY DISCLAIM ANY AND ALL IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT,
7+
// MERCHANTABILITY, AND FITNESS FOR A PARTICULAR PURPOSE. FURTHERMORE, ORACLE AND ITS AFFILIATES
8+
// DO NOT REPRESENT THAT ANY CUSTOMARY SECURITY REVIEW HAS BEEN PERFORMED WITH RESPECT TO ANY SOFTWARE,
9+
// MATERIAL OR CONTENT CONTAINED OR PRODUCED WITHIN THIS REPOSITORY. IN ADDITION, AND WITHOUT LIMITING
10+
// THE FOREGOING, THIRD PARTIES MAY HAVE POSTED SOFTWARE, MATERIAL OR CONTENT TO THIS REPOSITORY WITHOUT
11+
// ANY REVIEW. USE AT YOUR OWN RISK.
12+
13+
package com.example.demoai;
14+
15+
import java.util.List;
16+
import java.util.Map;
17+
import java.util.Optional;
18+
import java.util.Arrays;
19+
import java.util.Collections;
20+
21+
import org.springframework.ai.document.Document;
22+
import org.springframework.ai.embedding.EmbeddingClient;
23+
import org.springframework.ai.vectorstore.SearchRequest;
24+
import org.springframework.ai.vectorstore.VectorStore;
25+
import org.springframework.beans.factory.InitializingBean;
26+
import org.springframework.beans.factory.annotation.Value;
27+
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
28+
import org.springframework.jdbc.core.JdbcTemplate;
29+
import org.springframework.jdbc.core.PreparedStatementSetter;
30+
import org.springframework.jdbc.core.RowMapper;
31+
32+
import com.example.model.VectorData;
33+
34+
35+
//Oracle DB
36+
import oracle.jdbc.OracleType;
37+
import oracle.sql.json.OracleJsonFactory;
38+
import oracle.sql.json.OracleJsonObject;
39+
40+
41+
import java.sql.*;
42+
import java.util.ArrayList;
43+
import java.util.HashMap;
44+
45+
import org.slf4j.Logger;
46+
import org.slf4j.LoggerFactory;
47+
48+
public class OracleDBVectorStore implements VectorStore, InitializingBean {
49+
50+
private static final List<String> DISTANCE_METRICS = Collections.unmodifiableList(
51+
List.of("MANHATTAN", "EUCLIDEAN", "DOT", "COSINE"));
52+
53+
private Map<String, String> DISTANCE_METRICS_FUNC;
54+
55+
@Value("${config.vectorDB:vectortab}")
56+
public String VECTOR_TABLE;
57+
58+
public int BATCH_SIZE = 100;
59+
60+
private JdbcTemplate jdbcTemplate;
61+
62+
EmbeddingClient embeddingClient;
63+
64+
@Value("${config.dropDb}")
65+
private Boolean dropAtStartup;
66+
67+
@Value("${config.distance}")
68+
private String distance_metric = "COSINE";
69+
70+
private static final Logger logger = LoggerFactory.getLogger(OracleDBVectorStore.class);
71+
72+
public OracleDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingClient embClient) {
73+
74+
this.jdbcTemplate = jdbcTemplate;
75+
this.embeddingClient = embClient;
76+
this.DISTANCE_METRICS_FUNC = new HashMap<>();
77+
this.DISTANCE_METRICS_FUNC.put("MANHATTAN", "L1_DISTANCE");
78+
this.DISTANCE_METRICS_FUNC.put("EUCLIDEAN", "L2_DISTANCE");
79+
this.DISTANCE_METRICS_FUNC.put("DOT", "INNER_PRODUCT");
80+
this.DISTANCE_METRICS_FUNC.put("COSINE", "COSINE_DISTANCE");
81+
}
82+
83+
@Override
84+
public void add(List<Document> documents) {
85+
86+
int size = documents.size();
87+
88+
this.jdbcTemplate.batchUpdate("INSERT INTO " + this.VECTOR_TABLE + " (text,embeddings,metadata) VALUES (?,?,?)",
89+
new BatchPreparedStatementSetter() {
90+
@Override
91+
public void setValues(PreparedStatement ps, int i) throws SQLException {
92+
93+
var document = documents.get(i);
94+
var text = document.getContent();
95+
96+
OracleJsonFactory factory = new OracleJsonFactory();
97+
OracleJsonObject jsonObj = factory.createObject();
98+
Map<String, Object> metaData = document.getMetadata();
99+
for (Map.Entry<String, Object> entry : metaData.entrySet()) {
100+
jsonObj.put(entry.getKey(), String.valueOf(entry.getValue()));
101+
}
102+
103+
List<Double> vectorList = embeddingClient.embed(document);
104+
double[] embeddings = new double[vectorList.size()];
105+
for (int j = 0; j < vectorList.size(); j++) {
106+
embeddings[j] = vectorList.get(j);
107+
}
108+
109+
ps.setString(1, text);
110+
ps.setObject(2, embeddings, OracleType.VECTOR);
111+
ps.setObject(3, jsonObj, OracleType.JSON);
112+
113+
}
114+
115+
@Override
116+
public int getBatchSize() {
117+
return size;
118+
}
119+
});
120+
121+
}
122+
123+
124+
@Override
125+
public Optional<Boolean> delete(List<String> idList) {
126+
127+
String sql = "DELETE FROM " + this.VECTOR_TABLE + " WHERE id = ?";
128+
int count[][] = jdbcTemplate.batchUpdate(sql, idList, BATCH_SIZE, (ps, argument) -> {
129+
ps.setString(1, argument);
130+
});
131+
132+
int sum = Arrays.stream(count).flatMapToInt(Arrays::stream).sum();
133+
logger.info("MSG: Deleted " + sum + " records");
134+
135+
return Optional.of(sum == idList.size());
136+
}
137+
138+
139+
@Override
140+
public List<Document> similaritySearch(SearchRequest request) {
141+
142+
List<VectorData> nearest = new ArrayList<>();
143+
144+
logger.info("MSG: REQUESTED QUERY " + request.getQuery());
145+
List<Double> queryEmbeddings = embeddingClient.embed(request.getQuery());
146+
logger.info("MSG: EMBEDDINGS SIZE: " + queryEmbeddings.size());
147+
148+
logger.info("MSG: DISTANCE METRICS: " + distance_metric);
149+
if (DISTANCE_METRICS_FUNC.get(distance_metric) == null) {
150+
logger.error(
151+
"ERROR: wrong distance metrics set. Allowed values are: " + String.join(",", DISTANCE_METRICS));
152+
System.exit(1);
153+
}
154+
logger.info("MSG: DISTANCE METRICS FUNCTION: " + DISTANCE_METRICS_FUNC.get(distance_metric));
155+
int topK = request.getTopK();
156+
157+
try {
158+
nearest = similaritySearchByMetrics(VECTOR_TABLE, queryEmbeddings, topK,
159+
this.DISTANCE_METRICS_FUNC.get(distance_metric));
160+
} catch (Exception e) {
161+
logger.error(e.toString());
162+
}
163+
164+
List<Document> documents = new ArrayList<>();
165+
166+
for (VectorData d : nearest) {
167+
OracleJsonObject metadata = d.getMetadata();
168+
Map<String, Object> map = new HashMap<>();
169+
for (String key : metadata.keySet()) {
170+
map.put(key, metadata.get(key).toString());
171+
}
172+
Document doc = new Document(d.getText(), map);
173+
documents.add(doc);
174+
175+
}
176+
return documents;
177+
178+
}
179+
180+
List<VectorData> similaritySearchByMetrics(String vectortab, List<Double> vector, int topK,
181+
String distance_metrics_func) throws SQLException {
182+
List<VectorData> results = new ArrayList<>();
183+
double[] doubleVector = new double[vector.size()];
184+
for (int i = 0; i < vector.size(); i++) {
185+
doubleVector[i] = vector.get(i);
186+
}
187+
188+
try {
189+
190+
String similaritySql = "SELECT id,embeddings,metadata,text FROM " + vectortab
191+
+ " ORDER BY " + distance_metrics_func + "(embeddings, ?)"
192+
+ " FETCH FIRST ? ROWS ONLY";
193+
194+
results = jdbcTemplate.query(similaritySql,
195+
new PreparedStatementSetter() {
196+
public void setValues(java.sql.PreparedStatement ps) throws SQLException {
197+
ps.setObject(1, doubleVector, OracleType.VECTOR);
198+
ps.setObject(2, topK, OracleType.NUMBER);
199+
}
200+
},
201+
new RowMapper<VectorData>() {
202+
public VectorData mapRow(ResultSet rs, int rowNum) throws SQLException {
203+
return new VectorData(rs.getString("id"),
204+
rs.getObject("embeddings", double[].class),
205+
rs.getObject("text", String.class),
206+
rs.getObject("metadata", OracleJsonObject.class));
207+
}
208+
});
209+
210+
} catch (Exception e) {
211+
logger.error("ERROR: " + e.getMessage());
212+
}
213+
return results;
214+
}
215+
216+
217+
@Override
218+
public void afterPropertiesSet() throws Exception {
219+
int initialSize = 0;
220+
try {
221+
222+
initialSize = jdbcTemplate.queryForObject("select count(*) from " + this.VECTOR_TABLE, Integer.class);
223+
logger.info("MSG: table " + this.VECTOR_TABLE + " exists with " + initialSize + " chunks");
224+
225+
if (dropAtStartup) {
226+
logger.info("MSG: DROP TABLE " + this.VECTOR_TABLE + " AT EVERY STARTUP");
227+
throw new Exception("DROP TABLE EVERY STARTUP");
228+
}
229+
} catch (Exception e) {
230+
try {
231+
logger.info("MSG: DROPPING TABLE " + this.VECTOR_TABLE);
232+
jdbcTemplate.execute(
233+
"BEGIN\n" +
234+
" EXECUTE IMMEDIATE 'DROP TABLE " + this.VECTOR_TABLE + " CASCADE CONSTRAINTS';\n" +
235+
"EXCEPTION\n" +
236+
" WHEN OTHERS THEN\n" +
237+
" IF SQLCODE != -942 THEN\n" +
238+
" RAISE;\n" +
239+
" END IF;\n" +
240+
"END;");
241+
} catch (Exception ex1) {
242+
logger.error("ERROR: DROP TABLE " + this.VECTOR_TABLE + " \n" + e.getMessage());
243+
System.exit(1);
244+
}
245+
String createTableSql = "CREATE TABLE " + this.VECTOR_TABLE + "(" +
246+
"id NUMBER GENERATED AS IDENTITY ," +
247+
"text CLOB," +
248+
"embeddings VECTOR," +
249+
"metadata JSON," +
250+
"PRIMARY KEY (id))";
251+
try {
252+
this.jdbcTemplate.execute(
253+
"BEGIN\n" +
254+
" EXECUTE IMMEDIATE '" +
255+
createTableSql + "' ;\n" +
256+
"EXCEPTION\n" +
257+
" WHEN OTHERS THEN\n" +
258+
" IF SQLCODE != -942 THEN\n" +
259+
" RAISE;\n" +
260+
" END IF;\n" +
261+
"END;");
262+
logger.info("OK: CREATE TABLE " + this.VECTOR_TABLE);
263+
} catch (Exception ex2) {
264+
logger.error("ERROR: CREATE TABLE" + e.getMessage());
265+
System.exit(1);
266+
}
267+
}
268+
return;
269+
}
270+
271+
}

0 commit comments

Comments
 (0)