Skip to content

Commit 2605913

Browse files
committed
Add Bedrock
1 parent b0452d1 commit 2605913

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
name: AstraDev | AWS Bedrock
2+
3+
on:
4+
#push:
5+
# branches: [ main ]
6+
#pull_request:
7+
# branches: [ main ]
8+
workflow_dispatch:
9+
10+
jobs:
11+
BUILD:
12+
runs-on: ubuntu-latest
13+
steps:
14+
- name: Checkout Code
15+
uses: actions/checkout@v2
16+
- name: Set up JDK 11
17+
uses: actions/setup-java@v2
18+
with:
19+
distribution: 'adopt'
20+
java-version: 11
21+
- name: Build with Maven
22+
run: |
23+
mvn install -Dmaven.test.skip=true
24+
25+
AWS_BEDROCK:
26+
needs: BUILD
27+
runs-on: ubuntu-latest
28+
steps:
29+
- name: Checkout repository
30+
uses: actions/checkout@v2
31+
- name: Set up JDK
32+
uses: actions/setup-java@v2
33+
with:
34+
distribution: 'adopt'
35+
java-version: 11
36+
- name: Maven Test
37+
env:
38+
ASTRA_DB_APPLICATION_TOKEN_DEV: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN_DEV }}
39+
ASTRA_CLOUD_PROVIDER_DEV: AWS
40+
ASTRA_CLOUD_REGION_DEV: us-west-2
41+
EMBEDDING_PROVIDER: bedrock
42+
BEDROCK_HEADER_AWS_ACCESS_ID: ${{ secrets.BEDROCK_HEADER_AWS_ACCESS_ID }}
43+
BEDROCK_HEADER_AWS_SECRET_IDx: ${{ secrets.BEDROCK_HEADER_AWS_SECRET_ID }}
44+
BEDROCK_REGION: ${{ vars.BEDROCK_REGION }}
45+
run: |
46+
cd astra-db-java
47+
mvn test -Dtest=com.datastax.astra.test.integration.dev_vectorize.AstraDevVectorizeAwsBedRockITTest
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,109 @@
11
package com.datastax.astra.test.integration.dev_vectorize;
22

3-
public class AstraDevVectorizeAwsBedRockITTest {
3+
import com.datastax.astra.client.Collection;
4+
import com.datastax.astra.client.DataAPIClient;
5+
import com.datastax.astra.client.DataAPIOptions;
6+
import com.datastax.astra.client.Database;
7+
import com.datastax.astra.client.admin.DatabaseAdmin;
8+
import com.datastax.astra.client.auth.AWSEmbeddingHeadersProvider;
9+
import com.datastax.astra.client.auth.EmbeddingHeadersProvider;
10+
import com.datastax.astra.client.model.CollectionOptions;
11+
import com.datastax.astra.client.model.DataAPIKeywords;
12+
import com.datastax.astra.client.model.Document;
13+
import com.datastax.astra.client.model.FindEmbeddingProvidersResult;
14+
import com.datastax.astra.client.model.FindOneOptions;
15+
import com.datastax.astra.client.model.InsertManyOptions;
16+
import com.datastax.astra.client.model.InsertManyResult;
17+
import com.datastax.astra.client.model.Projections;
18+
import com.datastax.astra.internal.command.LoggingCommandObserver;
19+
import com.datastax.astra.test.integration.AbstractVectorizeITTest;
20+
import com.dtsx.astra.sdk.db.domain.CloudProviderType;
21+
import com.dtsx.astra.sdk.utils.AstraEnvironment;
22+
import lombok.extern.slf4j.Slf4j;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
25+
26+
import java.util.List;
27+
import java.util.Map;
28+
import java.util.Optional;
29+
30+
import static org.assertj.core.api.Assertions.assertThat;
31+
32+
@Slf4j
33+
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN_DEV", matches = "Astra.*")
34+
@EnabledIfEnvironmentVariable(named = "ASTRA_CLOUD_PROVIDER_DEV", matches = ".*")
35+
@EnabledIfEnvironmentVariable(named = "ASTRA_CLOUD_REGION_DEV", matches = ".*")
36+
@EnabledIfEnvironmentVariable(named = "EMBEDDING_PROVIDER", matches = ".*")
37+
@EnabledIfEnvironmentVariable(named = "BEDROCK_HEADER_AWS_ACCESS_ID", matches = ".*")
38+
@EnabledIfEnvironmentVariable(named = "BEDROCK_HEADER_AWS_SECRET_ID", matches = ".*")
39+
@EnabledIfEnvironmentVariable(named = "BEDROCK_REGION", matches = ".*")
40+
public class AstraDevVectorizeAwsBedRockITTest extends AbstractVectorizeITTest {
41+
42+
@Override
43+
public AstraEnvironment getAstraEnvironment() {
44+
return AstraEnvironment.DEV;
45+
}
46+
47+
@Override
48+
public CloudProviderType getCloudProvider() {
49+
return CloudProviderType.valueOf(System.getenv("ASTRA_CLOUD_PROVIDER_DEV"));
50+
}
51+
52+
@Override
53+
public String getRegion() {
54+
return System.getenv("ASTRA_CLOUD_REGION_DEV");
55+
}
56+
57+
@Test
58+
public void shouldTestAwsBedRock() {
59+
String token = System.getenv("ASTRA_DB_APPLICATION_TOKEN_DEV");
60+
EmbeddingHeadersProvider awsAuthProvider = new AWSEmbeddingHeadersProvider(
61+
System.getenv("BEDROCK_HEADER_AWS_ACCESS_ID"),
62+
System.getenv("BEDROCK_HEADER_AWS_SECRET_ID")
63+
);
64+
65+
String providerName = "bedrock";
66+
String providerModel = "amazon.titan-embed-text-v1";
67+
String collectionName = "aws_bedrock_titan_v1";
68+
69+
// Validate that 'bedrock' is a valid provider
70+
FindEmbeddingProvidersResult result = databaseAdmin.findEmbeddingProviders();
71+
assertThat(result).isNotNull();
72+
assertThat(result.getEmbeddingProviders()).isNotNull();
73+
assertThat(result.getEmbeddingProviders()).containsKey(providerName);
74+
75+
// Create collection for AWS Bedrock
76+
Collection<Document> collection = getDatabase().createCollection(collectionName, CollectionOptions
77+
.builder()
78+
.vectorize(providerName, providerModel,null,
79+
Map.of("region", System.getenv("BEDROCK_REGION")))
80+
.build());;
81+
assertThat(getDatabase().collectionExists(collectionName)).isTrue();
82+
// Insertion With Vectorize
83+
List<Document> entries = List.of(
84+
new Document(1).vectorize("A lovestruck Romeo sings the streets a serenade"),
85+
new Document(2).vectorize("Finds a streetlight, steps out of the shade"),
86+
new Document(3).vectorize("Says something like, You and me babe, how about it?"),
87+
new Document(4).vectorize("Juliet says,Hey, it's Romeo, you nearly gimme a heart attack"),
88+
new Document(5).vectorize("He's underneath the window"),
89+
new Document(6).vectorize("She's singing, Hey la, my boyfriend's back"),
90+
new Document(7).vectorize("You shouldn't come around here singing up at people like that"),
91+
new Document(8).vectorize("Anyway, what you gonna do about it?")
92+
);
93+
94+
InsertManyResult res = collection.insertMany(entries, new InsertManyOptions().embeddingAuthProvider(awsAuthProvider));
95+
assertThat(res.getInsertedIds()).hasSize(8);
96+
log.info("{} Documents inserted", res.getInsertedIds().size());
97+
Optional<Document> doc = collection.findOne(null,
98+
new FindOneOptions()
99+
.sort("You shouldn't come around here singing up at people like tha")
100+
.projection(Projections.exclude(DataAPIKeywords.VECTOR.getKeyword()))
101+
.embeddingAuthProvider(awsAuthProvider)
102+
.includeSimilarity());
103+
log.info("Document found {}", doc);
104+
assertThat(doc).isPresent();
105+
assertThat(doc.get().getId(Integer.class)).isEqualTo(7);
106+
assertThat(doc.get().getDouble(DataAPIKeywords.SIMILARITY.getKeyword())).isGreaterThan(.8);
107+
}
108+
4109
}

0 commit comments

Comments
 (0)