|
| 1 | +// Copyright (c) 2024, Oracle and/or its affiliates. |
| 2 | +// Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 3 | + |
| 4 | +// ORACLE AND ITS AFFILIATES DO NOT PROVIDE ANY WARRANTY WHATSOEVER, EXPRESS OR IMPLIED, |
| 5 | +// FOR ANY SOFTWARE, MATERIAL OR CONTENT OF ANY KIND CONTAINED OR PRODUCED WITHIN THIS REPOSITORY, |
| 6 | +// AND IN PARTICULAR SPECIFICALLY DISCLAIM ANY AND ALL IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, |
| 7 | +// MERCHANTABILITY, AND FITNESS FOR A PARTICULAR PURPOSE. FURTHERMORE, ORACLE AND ITS AFFILIATES |
| 8 | +// DO NOT REPRESENT THAT ANY CUSTOMARY SECURITY REVIEW HAS BEEN PERFORMED WITH RESPECT TO ANY SOFTWARE, |
| 9 | +// MATERIAL OR CONTENT CONTAINED OR PRODUCED WITHIN THIS REPOSITORY. IN ADDITION, AND WITHOUT LIMITING |
| 10 | +// THE FOREGOING, THIRD PARTIES MAY HAVE POSTED SOFTWARE, MATERIAL OR CONTENT TO THIS REPOSITORY WITHOUT |
| 11 | +// ANY REVIEW. USE AT YOUR OWN RISK. |
| 12 | + |
| 13 | +package com.example.demoai; |
| 14 | + |
| 15 | +import java.util.List; |
| 16 | +import java.util.Map; |
| 17 | +import java.util.Optional; |
| 18 | +import java.util.Arrays; |
| 19 | +import java.util.Collections; |
| 20 | + |
| 21 | +import org.springframework.ai.document.Document; |
| 22 | +import org.springframework.ai.embedding.EmbeddingClient; |
| 23 | +import org.springframework.ai.vectorstore.SearchRequest; |
| 24 | +import org.springframework.ai.vectorstore.VectorStore; |
| 25 | +import org.springframework.beans.factory.InitializingBean; |
| 26 | +import org.springframework.beans.factory.annotation.Value; |
| 27 | +import org.springframework.jdbc.core.BatchPreparedStatementSetter; |
| 28 | +import org.springframework.jdbc.core.JdbcTemplate; |
| 29 | +import org.springframework.jdbc.core.PreparedStatementSetter; |
| 30 | +import org.springframework.jdbc.core.RowMapper; |
| 31 | + |
| 32 | +import com.example.model.VectorData; |
| 33 | + |
| 34 | + |
| 35 | +//Oracle DB |
| 36 | +import oracle.jdbc.OracleType; |
| 37 | +import oracle.sql.json.OracleJsonFactory; |
| 38 | +import oracle.sql.json.OracleJsonObject; |
| 39 | + |
| 40 | + |
| 41 | +import java.sql.*; |
| 42 | +import java.util.ArrayList; |
| 43 | +import java.util.HashMap; |
| 44 | + |
| 45 | +import org.slf4j.Logger; |
| 46 | +import org.slf4j.LoggerFactory; |
| 47 | + |
| 48 | +public class OracleDBVectorStore implements VectorStore, InitializingBean { |
| 49 | + |
| 50 | + private static final List<String> DISTANCE_METRICS = Collections.unmodifiableList( |
| 51 | + List.of("MANHATTAN", "EUCLIDEAN", "DOT", "COSINE")); |
| 52 | + |
| 53 | + private Map<String, String> DISTANCE_METRICS_FUNC; |
| 54 | + |
| 55 | + @Value("${config.vectorDB:vectortab}") |
| 56 | + public String VECTOR_TABLE; |
| 57 | + |
| 58 | + public int BATCH_SIZE = 100; |
| 59 | + |
| 60 | + private JdbcTemplate jdbcTemplate; |
| 61 | + |
| 62 | + EmbeddingClient embeddingClient; |
| 63 | + |
| 64 | + @Value("${config.dropDb}") |
| 65 | + private Boolean dropAtStartup; |
| 66 | + |
| 67 | + @Value("${config.distance}") |
| 68 | + private String distance_metric = "COSINE"; |
| 69 | + |
| 70 | + private static final Logger logger = LoggerFactory.getLogger(OracleDBVectorStore.class); |
| 71 | + |
| 72 | + public OracleDBVectorStore(JdbcTemplate jdbcTemplate, EmbeddingClient embClient) { |
| 73 | + |
| 74 | + this.jdbcTemplate = jdbcTemplate; |
| 75 | + this.embeddingClient = embClient; |
| 76 | + this.DISTANCE_METRICS_FUNC = new HashMap<>(); |
| 77 | + this.DISTANCE_METRICS_FUNC.put("MANHATTAN", "L1_DISTANCE"); |
| 78 | + this.DISTANCE_METRICS_FUNC.put("EUCLIDEAN", "L2_DISTANCE"); |
| 79 | + this.DISTANCE_METRICS_FUNC.put("DOT", "INNER_PRODUCT"); |
| 80 | + this.DISTANCE_METRICS_FUNC.put("COSINE", "COSINE_DISTANCE"); |
| 81 | + } |
| 82 | + |
| 83 | + @Override |
| 84 | + public void add(List<Document> documents) { |
| 85 | + |
| 86 | + int size = documents.size(); |
| 87 | + |
| 88 | + this.jdbcTemplate.batchUpdate("INSERT INTO " + this.VECTOR_TABLE + " (text,embeddings,metadata) VALUES (?,?,?)", |
| 89 | + new BatchPreparedStatementSetter() { |
| 90 | + @Override |
| 91 | + public void setValues(PreparedStatement ps, int i) throws SQLException { |
| 92 | + |
| 93 | + var document = documents.get(i); |
| 94 | + var text = document.getContent(); |
| 95 | + |
| 96 | + OracleJsonFactory factory = new OracleJsonFactory(); |
| 97 | + OracleJsonObject jsonObj = factory.createObject(); |
| 98 | + Map<String, Object> metaData = document.getMetadata(); |
| 99 | + for (Map.Entry<String, Object> entry : metaData.entrySet()) { |
| 100 | + jsonObj.put(entry.getKey(), String.valueOf(entry.getValue())); |
| 101 | + } |
| 102 | + |
| 103 | + List<Double> vectorList = embeddingClient.embed(document); |
| 104 | + double[] embeddings = new double[vectorList.size()]; |
| 105 | + for (int j = 0; j < vectorList.size(); j++) { |
| 106 | + embeddings[j] = vectorList.get(j); |
| 107 | + } |
| 108 | + |
| 109 | + ps.setString(1, text); |
| 110 | + ps.setObject(2, embeddings, OracleType.VECTOR); |
| 111 | + ps.setObject(3, jsonObj, OracleType.JSON); |
| 112 | + |
| 113 | + } |
| 114 | + |
| 115 | + @Override |
| 116 | + public int getBatchSize() { |
| 117 | + return size; |
| 118 | + } |
| 119 | + }); |
| 120 | + |
| 121 | + } |
| 122 | + |
| 123 | + |
| 124 | + @Override |
| 125 | + public Optional<Boolean> delete(List<String> idList) { |
| 126 | + |
| 127 | + String sql = "DELETE FROM " + this.VECTOR_TABLE + " WHERE id = ?"; |
| 128 | + int count[][] = jdbcTemplate.batchUpdate(sql, idList, BATCH_SIZE, (ps, argument) -> { |
| 129 | + ps.setString(1, argument); |
| 130 | + }); |
| 131 | + |
| 132 | + int sum = Arrays.stream(count).flatMapToInt(Arrays::stream).sum(); |
| 133 | + logger.info("MSG: Deleted " + sum + " records"); |
| 134 | + |
| 135 | + return Optional.of(sum == idList.size()); |
| 136 | + } |
| 137 | + |
| 138 | + |
| 139 | + @Override |
| 140 | + public List<Document> similaritySearch(SearchRequest request) { |
| 141 | + |
| 142 | + List<VectorData> nearest = new ArrayList<>(); |
| 143 | + |
| 144 | + logger.info("MSG: REQUESTED QUERY " + request.getQuery()); |
| 145 | + List<Double> queryEmbeddings = embeddingClient.embed(request.getQuery()); |
| 146 | + logger.info("MSG: EMBEDDINGS SIZE: " + queryEmbeddings.size()); |
| 147 | + |
| 148 | + logger.info("MSG: DISTANCE METRICS: " + distance_metric); |
| 149 | + if (DISTANCE_METRICS_FUNC.get(distance_metric) == null) { |
| 150 | + logger.error( |
| 151 | + "ERROR: wrong distance metrics set. Allowed values are: " + String.join(",", DISTANCE_METRICS)); |
| 152 | + System.exit(1); |
| 153 | + } |
| 154 | + logger.info("MSG: DISTANCE METRICS FUNCTION: " + DISTANCE_METRICS_FUNC.get(distance_metric)); |
| 155 | + int topK = request.getTopK(); |
| 156 | + |
| 157 | + try { |
| 158 | + nearest = similaritySearchByMetrics(VECTOR_TABLE, queryEmbeddings, topK, |
| 159 | + this.DISTANCE_METRICS_FUNC.get(distance_metric)); |
| 160 | + } catch (Exception e) { |
| 161 | + logger.error(e.toString()); |
| 162 | + } |
| 163 | + |
| 164 | + List<Document> documents = new ArrayList<>(); |
| 165 | + |
| 166 | + for (VectorData d : nearest) { |
| 167 | + OracleJsonObject metadata = d.getMetadata(); |
| 168 | + Map<String, Object> map = new HashMap<>(); |
| 169 | + for (String key : metadata.keySet()) { |
| 170 | + map.put(key, metadata.get(key).toString()); |
| 171 | + } |
| 172 | + Document doc = new Document(d.getText(), map); |
| 173 | + documents.add(doc); |
| 174 | + |
| 175 | + } |
| 176 | + return documents; |
| 177 | + |
| 178 | + } |
| 179 | + |
| 180 | + List<VectorData> similaritySearchByMetrics(String vectortab, List<Double> vector, int topK, |
| 181 | + String distance_metrics_func) throws SQLException { |
| 182 | + List<VectorData> results = new ArrayList<>(); |
| 183 | + double[] doubleVector = new double[vector.size()]; |
| 184 | + for (int i = 0; i < vector.size(); i++) { |
| 185 | + doubleVector[i] = vector.get(i); |
| 186 | + } |
| 187 | + |
| 188 | + try { |
| 189 | + |
| 190 | + String similaritySql = "SELECT id,embeddings,metadata,text FROM " + vectortab |
| 191 | + + " ORDER BY " + distance_metrics_func + "(embeddings, ?)" |
| 192 | + + " FETCH FIRST ? ROWS ONLY"; |
| 193 | + |
| 194 | + results = jdbcTemplate.query(similaritySql, |
| 195 | + new PreparedStatementSetter() { |
| 196 | + public void setValues(java.sql.PreparedStatement ps) throws SQLException { |
| 197 | + ps.setObject(1, doubleVector, OracleType.VECTOR); |
| 198 | + ps.setObject(2, topK, OracleType.NUMBER); |
| 199 | + } |
| 200 | + }, |
| 201 | + new RowMapper<VectorData>() { |
| 202 | + public VectorData mapRow(ResultSet rs, int rowNum) throws SQLException { |
| 203 | + return new VectorData(rs.getString("id"), |
| 204 | + rs.getObject("embeddings", double[].class), |
| 205 | + rs.getObject("text", String.class), |
| 206 | + rs.getObject("metadata", OracleJsonObject.class)); |
| 207 | + } |
| 208 | + }); |
| 209 | + |
| 210 | + } catch (Exception e) { |
| 211 | + logger.error("ERROR: " + e.getMessage()); |
| 212 | + } |
| 213 | + return results; |
| 214 | + } |
| 215 | + |
| 216 | + |
| 217 | + @Override |
| 218 | + public void afterPropertiesSet() throws Exception { |
| 219 | + int initialSize = 0; |
| 220 | + try { |
| 221 | + |
| 222 | + initialSize = jdbcTemplate.queryForObject("select count(*) from " + this.VECTOR_TABLE, Integer.class); |
| 223 | + logger.info("MSG: table " + this.VECTOR_TABLE + " exists with " + initialSize + " chunks"); |
| 224 | + |
| 225 | + if (dropAtStartup) { |
| 226 | + logger.info("MSG: DROP TABLE " + this.VECTOR_TABLE + " AT EVERY STARTUP"); |
| 227 | + throw new Exception("DROP TABLE EVERY STARTUP"); |
| 228 | + } |
| 229 | + } catch (Exception e) { |
| 230 | + try { |
| 231 | + logger.info("MSG: DROPPING TABLE " + this.VECTOR_TABLE); |
| 232 | + jdbcTemplate.execute( |
| 233 | + "BEGIN\n" + |
| 234 | + " EXECUTE IMMEDIATE 'DROP TABLE " + this.VECTOR_TABLE + " CASCADE CONSTRAINTS';\n" + |
| 235 | + "EXCEPTION\n" + |
| 236 | + " WHEN OTHERS THEN\n" + |
| 237 | + " IF SQLCODE != -942 THEN\n" + |
| 238 | + " RAISE;\n" + |
| 239 | + " END IF;\n" + |
| 240 | + "END;"); |
| 241 | + } catch (Exception ex1) { |
| 242 | + logger.error("ERROR: DROP TABLE " + this.VECTOR_TABLE + " \n" + e.getMessage()); |
| 243 | + System.exit(1); |
| 244 | + } |
| 245 | + String createTableSql = "CREATE TABLE " + this.VECTOR_TABLE + "(" + |
| 246 | + "id NUMBER GENERATED AS IDENTITY ," + |
| 247 | + "text CLOB," + |
| 248 | + "embeddings VECTOR," + |
| 249 | + "metadata JSON," + |
| 250 | + "PRIMARY KEY (id))"; |
| 251 | + try { |
| 252 | + this.jdbcTemplate.execute( |
| 253 | + "BEGIN\n" + |
| 254 | + " EXECUTE IMMEDIATE '" + |
| 255 | + createTableSql + "' ;\n" + |
| 256 | + "EXCEPTION\n" + |
| 257 | + " WHEN OTHERS THEN\n" + |
| 258 | + " IF SQLCODE != -942 THEN\n" + |
| 259 | + " RAISE;\n" + |
| 260 | + " END IF;\n" + |
| 261 | + "END;"); |
| 262 | + logger.info("OK: CREATE TABLE " + this.VECTOR_TABLE); |
| 263 | + } catch (Exception ex2) { |
| 264 | + logger.error("ERROR: CREATE TABLE" + e.getMessage()); |
| 265 | + System.exit(1); |
| 266 | + } |
| 267 | + } |
| 268 | + return; |
| 269 | + } |
| 270 | + |
| 271 | +} |
0 commit comments