Skip to content

Commit 58aa471

Browse files
authored
Merge pull request #133 from Ryszard-Trojnacki/feature/pgvector-container
Add PGvector test container
2 parents af2c990 + 70bce9f commit 58aa471

File tree

5 files changed

+171
-0
lines changed

5 files changed

+171
-0
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,19 @@ occurring automatically on JVM shutdown.
164164

165165
```
166166

167+
#### Postgres PGvector
168+
```java
169+
var container = PGvectorContainer.builder("pg18")
170+
.start();
171+
```
172+
173+
#### Postgres Postgis
174+
```java
175+
PostgisContainer container = PostgisContainer.builder("15")
176+
.useLW(true) // use LW compression JDBC urls
177+
.build();
178+
```
179+
167180
#### SqlServer
168181

169182
```java

pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@
151151
<scope>test</scope>
152152
</dependency>
153153

154+
<dependency>
155+
<groupId>com.pgvector</groupId>
156+
<artifactId>pgvector</artifactId>
157+
<version>0.1.6</version>
158+
<scope>test</scope>
159+
</dependency>
160+
154161
<dependency>
155162
<groupId>ru.yandex.clickhouse</groupId>
156163
<artifactId>clickhouse-jdbc</artifactId>

src/main/java/io/ebean/test/containers/ContainerFactory.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ private void init() {
6161
if (pgisVersion != null) {
6262
containers.add(PostgisContainer.builder(pgisVersion).properties(properties).build());
6363
}
64+
String pgvectorVersion = runWithVersion("pgvector");
65+
if (pgvectorVersion != null) {
66+
containers.add(PGvectorContainer.builder(pgvectorVersion).properties(properties).build());
67+
}
6468
String mysqlVersion = runWithVersion("mysql");
6569
if (mysqlVersion != null) {
6670
containers.add(MySqlContainer.builder(mysqlVersion).properties(properties).build());
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.ebean.test.containers;
2+
3+
/**
4+
* Commands for controlling a pgvector docker container.
5+
*/
6+
public class PGvectorContainer extends BasePostgresContainer<PGvectorContainer> {
7+
8+
@Override
9+
public PGvectorContainer start() {
10+
startOrThrow();
11+
return this;
12+
}
13+
14+
/**
15+
* Create a builder for PGvectorContainer.
16+
*/
17+
public static Builder builder(String version) {
18+
return new Builder(version);
19+
}
20+
21+
private PGvectorContainer(Builder config) {
22+
super(config);
23+
}
24+
25+
/**
26+
* Builder for Postgis container.
27+
*/
28+
public static class Builder extends BaseDbBuilder<PGvectorContainer, Builder> {
29+
30+
private Builder(String version) {
31+
super("pgvector", 6435, 5432, version);
32+
this.image = "pgvector/pgvector:" + version;
33+
this.adminUsername = "postgres";
34+
this.tmpfs = "/var/lib/postgresql/data:rw";
35+
this.extensions = "vector";
36+
this.extra.extensions = extensions;
37+
this.extra2.extensions = extensions;
38+
}
39+
40+
@Override
41+
protected String buildJdbcUrl() {
42+
return "jdbc:postgresql://" + host + ":" + port + "/" + dbName;
43+
}
44+
45+
@Override
46+
protected String buildJdbcAdminUrl() {
47+
return "jdbc:postgresql://" + host + ":" + port + "/postgres";
48+
}
49+
50+
@Override
51+
protected String buildExtraJdbcUrl(String dbName) {
52+
return "jdbc:postgresql://" + host + ":" + port + "/" + dbName;
53+
}
54+
55+
@Override
56+
public PGvectorContainer build() {
57+
return new PGvectorContainer(this);
58+
}
59+
60+
@Override
61+
public PGvectorContainer start() {
62+
return build().start();
63+
}
64+
}
65+
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package io.ebean.test.containers;
2+
3+
import io.ebean.Database;
4+
import io.ebean.datasource.DataSourcePool;
5+
import org.junit.jupiter.api.Test;
6+
7+
import java.sql.Connection;
8+
import java.sql.PreparedStatement;
9+
import java.sql.SQLException;
10+
import java.util.HashSet;
11+
12+
import com.pgvector.PGvector;
13+
14+
import static org.assertj.core.api.Assertions.assertThat;
15+
16+
public class PGvectorContainerTest {
17+
private final HashSet<Connection> connections = new HashSet<>();
18+
19+
/**
20+
* Helper function to register the PGvector types only once per connection.
21+
* @param connection the connection
22+
* @return the same connection
23+
* @throws SQLException if an error occurs
24+
*/
25+
private Connection wrapConnection(Connection connection) throws SQLException {
26+
if(connections.add(connection)) {
27+
PGvector.registerTypes(connection);
28+
}
29+
return connection;
30+
}
31+
32+
@Test
33+
void extraDb() throws java.sql.SQLException {
34+
PGvectorContainer container = PGvectorContainer.builder("pg18")
35+
.port(0)
36+
.extraDb("myextra")
37+
.build();
38+
39+
container.startMaybe();
40+
assertThat(container.port()).isGreaterThan(0);
41+
42+
ContainerConfig containerConfig = container.config();
43+
assertThat(containerConfig.port()).isEqualTo(container.port());
44+
45+
String jdbcUrl = container.config().jdbcUrl();
46+
assertThat(jdbcUrl).contains(":" + containerConfig.port());
47+
runSomeSql(container);
48+
49+
DataSourcePool dataSource = container.ebean().dataSourceBuilder().build();
50+
try (Connection connection = wrapConnection(dataSource.getConnection())) {
51+
exeSql(connection, "INSERT INTO items (embedding) values ('[7,8,9]')");
52+
}
53+
dataSource.shutdown();
54+
55+
Database ebean = container.ebean().builder()
56+
.register(false)
57+
.defaultDatabase(false)
58+
.build();
59+
// This can't be done yet, because Ebean doesn't know about PGvector type
60+
// ebean.sqlUpdate("insert into items (embedding) values (?)")
61+
// .setParameter(new PGvector(new float[] { 10f, 11f, 12f }))
62+
// .execute();
63+
64+
ebean.shutdown();
65+
}
66+
67+
private void runSomeSql(PGvectorContainer container) {
68+
try {
69+
Connection connection = wrapConnection(container.createConnection());
70+
exeSql(connection, "CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3))");
71+
exeSql(connection, "INSERT INTO items (embedding) VALUES ('[1,2,3]'), ('[4,5,6]')");
72+
} catch (SQLException e) {
73+
throw new RuntimeException(e);
74+
}
75+
}
76+
77+
private static void exeSql(Connection connection, String sql) throws SQLException {
78+
try (PreparedStatement st = connection.prepareStatement(sql)) {
79+
st.execute();
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)