Skip to content

Commit fda3b58

Browse files
committed
ok
1 parent f606652 commit fda3b58

File tree

45 files changed

+718
-56
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+718
-56
lines changed
File renamed without changes.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.datastax.astra.tool.loader.rag;
2+
3+
import com.datastax.astra.client.core.query.Projection;
4+
import com.datastax.astra.client.tables.Table;
5+
import com.datastax.astra.client.tables.commands.options.TableFindOptions;
6+
import com.datastax.astra.tool.loader.rag.sources.RagSource;
7+
8+
import java.util.UUID;
9+
10+
public class RagGenericTest {
11+
12+
public static void main(String[] args) {
13+
14+
String token = System.getenv("ASTRA_DB_APPLICATION_TOKEN");
15+
UUID TEST_TENANT = UUID.fromString("00000000-0000-0000-0000-000000000000");
16+
17+
//Database db = DataAPIClients.astra(token).getDatabase(TEST_TENANT);
18+
//System.out.println(db.getInfo().getName());
19+
RagRepository repo = new RagRepository(token, "goodbards");
20+
Table<RagSource> tableSources = repo.getTableRagSource(TEST_TENANT);
21+
22+
TableFindOptions options = new TableFindOptions()
23+
.projection(Projection.include("name", "source", "location"));
24+
tableSources.find(options).toList().forEach(System.out::println);
25+
}
26+
27+
28+
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package com.datastax.astra.tool.loader.rag;
2+
3+
import com.datastax.astra.client.DataAPIClient;
4+
import com.datastax.astra.client.DataAPIClients;
5+
import com.datastax.astra.client.admin.AstraDBAdmin;
6+
import com.datastax.astra.client.admin.DatabaseAdmin;
7+
import com.datastax.astra.client.core.vectorize.VectorServiceOptions;
8+
import com.datastax.astra.client.databases.Database;
9+
import com.datastax.astra.client.databases.DatabaseOptions;
10+
import com.datastax.astra.client.databases.definition.DatabaseInfo;
11+
import com.datastax.astra.client.tables.Table;
12+
import com.datastax.astra.client.tables.commands.options.CreateTableOptions;
13+
import com.datastax.astra.client.tables.commands.options.CreateVectorIndexOptions;
14+
import com.datastax.astra.client.tables.definition.rows.Row;
15+
import com.datastax.astra.internal.utils.Utils;
16+
import com.datastax.astra.tool.loader.rag.ingestion.RagEmbeddingsModels;
17+
import com.datastax.astra.tool.loader.rag.ingestion.RagIngestionConfig;
18+
import com.datastax.astra.tool.loader.rag.ingestion.RagIngestionJob;
19+
import com.datastax.astra.tool.loader.rag.sources.RagSource;
20+
import com.datastax.astra.tool.loader.rag.stores.RagStore;
21+
import com.dtsx.astra.sdk.db.domain.CloudProviderType;
22+
import lombok.extern.slf4j.Slf4j;
23+
24+
import java.util.Optional;
25+
import java.util.UUID;
26+
27+
@Slf4j
28+
public class RagRepository {
29+
30+
String token;
31+
32+
String keyspace;
33+
34+
CloudProviderType cloudProvider = CloudProviderType.AWS;
35+
36+
String cloudRegion = "us-east-2";
37+
38+
public RagRepository(String token, String keyspace) {
39+
this.token = token;
40+
this.keyspace = keyspace;
41+
}
42+
43+
public Database getOrCreateDatabase(UUID tenantId) {
44+
DataAPIClient dataApiClient = DataAPIClients.astra(token);
45+
AstraDBAdmin astraDBAdmin = dataApiClient.getAdmin();
46+
// Database
47+
Optional<DatabaseInfo> devopsDB = astraDBAdmin
48+
.listDatabases()
49+
.stream()
50+
.filter(db -> tenantId.toString().equals(db.getName()))
51+
.findFirst();
52+
if (devopsDB.isEmpty()) {
53+
log.info("Database {} does not exists and will be created.", tenantId.toString());
54+
DatabaseAdmin dbAdmin = astraDBAdmin
55+
.createDatabase(tenantId.toString(), cloudProvider, cloudRegion);
56+
dbAdmin.createKeyspace(keyspace, true);
57+
return dbAdmin.getDatabase(keyspace);
58+
}
59+
log.info("Database {} already exists.", tenantId);
60+
return dataApiClient.getDatabase(devopsDB.get().getId(), devopsDB.get().getRegion(),
61+
new DatabaseOptions()
62+
.token(token)
63+
.keyspace(keyspace)
64+
// reusing the logging
65+
.dataAPIClientOptions(dataApiClient.getOptions()));
66+
}
67+
68+
private <T> Table<T> getTable(Database db, Class<T> record) {
69+
String tableName = db.getTableName(record);
70+
db.useKeyspace(keyspace);
71+
if (!db.tableExists(tableName)) {
72+
log.info("Table {} does not exists, creating...", tableName);
73+
db.createTable(record, new CreateTableOptions().keyspace(keyspace));
74+
log.info("Table {} has been successfully created", tableName);
75+
}
76+
return db.getTable(record);
77+
}
78+
79+
// --------------------------------------------------------------------
80+
// Rag Source
81+
// --------------------------------------------------------------------
82+
83+
public Table<RagSource> getTableRagSource(UUID tenantId) {
84+
return getTable(getOrCreateDatabase(tenantId), RagSource.class);
85+
}
86+
87+
public UUID registerSource(UUID tenantId, RagSource source) {
88+
return UUID.fromString((String)
89+
getTableRagSource(tenantId)
90+
.insertOne(source)
91+
.getInsertedId()
92+
.get(0));
93+
}
94+
95+
// --------------------------------------------------------------------
96+
// Config
97+
// --------------------------------------------------------------------
98+
99+
public Table<RagIngestionConfig> getTableRagConfig(UUID tenantId) {
100+
return getTable(getOrCreateDatabase(tenantId), RagIngestionConfig.class);
101+
}
102+
103+
public UUID createConfig(UUID tenantId, RagIngestionConfig config) {
104+
return UUID.fromString((String)
105+
getTableRagConfig(tenantId)
106+
.insertOne(config)
107+
.getInsertedId()
108+
.get(0));
109+
}
110+
111+
// --------------------------------------------------------------------
112+
// Jobs
113+
// --------------------------------------------------------------------
114+
115+
public Table<RagIngestionJob> getTableRagJob(UUID tenantId) {
116+
return getTable(getOrCreateDatabase(tenantId), RagIngestionJob.class);
117+
}
118+
119+
// --------------------------------------------------------------------
120+
// Vector Stores
121+
// --------------------------------------------------------------------
122+
123+
public Table<RagStore> getTableRagStore(UUID tenantId, RagIngestionConfig config) {
124+
return getTableRagStore(getOrCreateDatabase(tenantId),
125+
config.getEmbeddingProvider(),
126+
config.getEmbeddingModel(),
127+
config.getEmbeddingDimension(), null);
128+
}
129+
130+
public Table<RagStore> getTableRagStore(Database db, String provider, String model, int dimension, VectorServiceOptions options) {
131+
Utils.hasLength(provider);
132+
Utils.hasLength(model);
133+
String tableName = RagStore.getTableName(provider, model);
134+
db.useKeyspace(keyspace);
135+
136+
if (!db.tableExists(tableName)) {
137+
log.info("Table {} does not exists, creating...", tableName);
138+
Table<Row> table = db.createTable(tableName,
139+
RagStore.getTableDefinition(dimension, options),
140+
new CreateTableOptions().keyspace(keyspace));
141+
log.info("Table {} has been successfully created", tableName);
142+
143+
table.createIndex(RagStore.getIndexName(provider, model) + "_sourceId","source_id");
144+
145+
String indexName = RagStore.getIndexName(provider, model);
146+
table.createVectorIndex(indexName,
147+
RagStore.getVectorIndexDefinition(options),
148+
CreateVectorIndexOptions.IF_NOT_EXISTS);
149+
log.info("Vector Index {} has been successfully created", indexName);
150+
}
151+
return db.getTable(tableName, RagStore.class);
152+
}
153+
154+
public Table<RagStore> getTableRagStore(Database db, RagEmbeddingsModels model) {
155+
return getTableRagStore(db, model.getProvider(), model.getName(), model.getDimension(), null);
156+
}
157+
158+
}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.datastax.astra.tool.loader.rag.ingestion;
2+
3+
/**
 * Catalog of embedding models, each described by a provider key, a model name and the
 * dimension of the vectors it produces.
 */
public enum RagEmbeddingsModels {
    NVIDIA_NEMO("nvidia", "NV-Embed-QA", 1024),
    OPENAI_ADA002("open-ai", "text-embedding-ada-002", 1536),
    OPENAI_3_SMALL("open-ai", "text-embedding-3-small", 1536),
    OPENAI_3_LARGE("open-ai", "text-embedding-3-large", 3072),
    AZURE_OPENAI_SMALL("azure_openai", "text-embedding-3-small", 512),
    AZURE_OPENAI_LARGE("azure_openai", "text-embedding-3-large", 1024),
    AZURE_OPENAI_ADA002("azure_openai", "text-embedding-ada-002", 1536),
    HF_MINI_LM_L6("huggingface", "sentence-transformers/all-MiniLM-L6-v2", 384),
    VERTEX_AI_GECKO_003("vertexai", "textembedding-gecko@003", 768),
    JINA_AI_EMBEDDINGS_V2_EN("jinaai", "jina-embeddings-v2-base-en", 768),
    JINA_AI_EMBEDDINGS_V2_DE("jinaai", "jina-embeddings-v2-base-de", 768),
    JINA_AI_EMBEDDINGS_V2_ES("jinaai", "jina-embeddings-v2-base-es", 768),
    JINA_AI_EMBEDDINGS_V2_ZH("jinaai", "jina-embeddings-v2-base-zh", 768),
    JINA_AI_EMBEDDINGS_V2_CODE("jinaai", "jina-embeddings-v2-base-code", 768),
    MISTRAL_AI("mistralai", "mistral-embed", 1024),
    // Model names must not carry surrounding whitespace: they are returned verbatim by getName().
    VOYAGE_AI_2("voyageai", "voyage-2", 1024),
    VOYAGE_AI_LAW_2("voyageai", "voyage-law-2", 1024),
    VOYAGE_AI_CODE_2("voyageai", "voyage-code-2", 1536),
    VOYAGE_AI_LARGE_2("voyageai", "voyage-large-2", 1536),
    VOYAGE_AI_LITE_INSTRUCT("voyageai", "voyage-lite-02-instruct", 1024),
    UPSTAGE_AI_SOLAR_MINI_1_QUERY("upstageai", "solar-1-mini-embedding-query", 4096),
    UPSTAGE_AI_SOLAR_MINI_1_PASSAGE("upstageai", "solar-1-mini-embedding-passage", 4096),
    COHERE_EMBED_ENGLISH_V2("cohere", "embed-english-v2.0", 4096),
    COHERE_EMBED_ENGLISH_V3("cohere", "embed-english-v3.0", 1024);

    /** Provider key of the model. */
    private final String provider;

    /** Model name. */
    private final String name;

    /** Dimension of the vectors produced by the model. */
    private final int dimension;

    /**
     * Binds a constant to its provider, model name and dimension.
     *
     * @param provider
     *      provider key
     * @param name
     *      model name
     * @param dimension
     *      vector dimension
     */
    RagEmbeddingsModels(String provider, String name, int dimension) {
        this.provider = provider;
        this.name = name;
        this.dimension = dimension;
    }

    /**
     * Gets provider
     *
     * @return value of provider
     */
    public String getProvider() {
        return provider;
    }

    /**
     * Gets name
     *
     * @return value of name
     */
    public String getName() {
        return name;
    }

    /**
     * Gets dimension
     *
     * @return value of dimension
     */
    public int getDimension() {
        return dimension;
    }
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package com.datastax.astra.tool.loader.rag.ingestion;
2+
3+
import com.datastax.astra.client.tables.definition.columns.ColumnTypes;
4+
import com.datastax.astra.client.tables.mapping.Column;
5+
import com.datastax.astra.client.tables.mapping.EntityTable;
6+
import com.datastax.astra.client.tables.mapping.PartitionBy;
7+
import lombok.Data;
8+
9+
import java.util.UUID;
10+
11+
@Data
12+
@EntityTable(RagIngestionConfig.TABLE_NAME)
13+
public class RagIngestionConfig {
14+
15+
public static final String TABLE_NAME = "rag_configs";
16+
17+
@PartitionBy(0)
18+
UUID uid = UUID.randomUUID();
19+
20+
@Column(name = "name", type = ColumnTypes.TEXT)
21+
String name;
22+
23+
@Column(name = "description", type = ColumnTypes.TEXT)
24+
String description;
25+
26+
// Splitting
27+
28+
@Column(name = "splitter", type = ColumnTypes.TEXT)
29+
String splitter;
30+
31+
@Column(name = "chunk_size", type = ColumnTypes.INT)
32+
Integer chunkSize;
33+
34+
@Column(name = "chunk_overlap", type = ColumnTypes.INT)
35+
Integer chunkOverlap;
36+
37+
// Embedding
38+
39+
@Column(name = "embedding_model", type = ColumnTypes.TEXT)
40+
String embeddingModel;
41+
42+
@Column(name = "embedding_provider", type = ColumnTypes.TEXT)
43+
String embeddingProvider;
44+
45+
@Column(name = "embedding_dimension", type = ColumnTypes.INT)
46+
Integer embeddingDimension;
47+
48+
// Post Processing
49+
50+
@Column(name = "context_before", type = ColumnTypes.INT)
51+
Integer contextBefore = 1;
52+
53+
@Column(name = "context_after", type = ColumnTypes.INT)
54+
Integer contextAfter = 2;
55+
56+
@Column(name = "enable_nlp_filter", type = ColumnTypes.BOOLEAN)
57+
boolean nlp;
58+
59+
@Column(name = "enable_hyde", type = ColumnTypes.BOOLEAN)
60+
boolean hyde;
61+
62+
public RagIngestionConfig withEmbedding(RagEmbeddingsModels model) {
63+
this.embeddingDimension = model.getDimension();
64+
this.embeddingProvider = model.getProvider();
65+
this.embeddingModel = model.getName();
66+
return this;
67+
}
68+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package com.datastax.astra.tool.loader.rag.ingestion;
2+
3+
import com.datastax.astra.client.core.query.SortOrder;
4+
import com.datastax.astra.client.tables.definition.columns.ColumnTypes;
5+
import com.datastax.astra.client.tables.mapping.Column;
6+
import com.datastax.astra.client.tables.mapping.EntityTable;
7+
import com.datastax.astra.client.tables.mapping.PartitionBy;
8+
import com.datastax.astra.client.tables.mapping.PartitionSort;
9+
import lombok.Data;
10+
11+
import java.time.Instant;
12+
import java.util.UUID;
13+
14+
@Data
15+
@EntityTable(RagIngestionJob.TABLE_NAME)
16+
public class RagIngestionJob {
17+
18+
public static final String TABLE_NAME = "rag_jobs";
19+
20+
@PartitionBy(0)
21+
@Column(name ="source_id", type=ColumnTypes.UUID)
22+
UUID sourceId;
23+
24+
@PartitionSort(position = 0, order=SortOrder.ASCENDING)
25+
@Column(name ="uid", type=ColumnTypes.UUID)
26+
UUID uid = UUID.randomUUID();
27+
28+
@Column(name ="config_id", type=ColumnTypes.UUID)
29+
UUID configId;
30+
31+
@Column(name ="start", type=ColumnTypes.TIMESTAMP)
32+
Instant start = Instant.now();
33+
34+
@Column(name ="stop", type=ColumnTypes.TIMESTAMP)
35+
Instant stop;
36+
37+
@Column(name ="elapsed", type=ColumnTypes.BIGINT)
38+
Long elapsed;
39+
40+
@Column(name ="chunk_count", type=ColumnTypes.INT)
41+
Integer chunkCount = 0;
42+
43+
@Column(name ="token_count", type=ColumnTypes.INT)
44+
Integer tokenCount = 0;
45+
46+
}

0 commit comments

Comments
 (0)