|
34 | 34 | import java.util.UUID; |
35 | 35 | import java.util.concurrent.TimeUnit; |
36 | 36 | import java.util.concurrent.atomic.AtomicInteger; |
| 37 | +import java.util.function.BiConsumer; |
37 | 38 | import java.util.function.BiPredicate; |
| 39 | +import java.util.function.Function; |
38 | 40 | import java.util.stream.Collectors; |
39 | 41 | import java.util.stream.Stream; |
40 | 42 | import java.util.stream.StreamSupport; |
|
49 | 51 |
|
50 | 52 | import com.datastax.driver.core.utils.UUIDs; |
51 | 53 | import org.apache.cassandra.Util; |
| 54 | +import org.apache.cassandra.cql3.CQL3Type; |
52 | 55 | import org.apache.cassandra.cql3.QueryProcessor; |
53 | 56 | import org.apache.cassandra.cql3.UntypedResultSet; |
54 | 57 | import org.apache.cassandra.cql3.functions.types.DataType; |
|
58 | 61 | import org.apache.cassandra.cql3.functions.types.UserType; |
59 | 62 | import org.apache.cassandra.db.ColumnFamilyStore; |
60 | 63 | import org.apache.cassandra.db.Keyspace; |
| 64 | +import org.apache.cassandra.db.marshal.AbstractType; |
| 65 | +import org.apache.cassandra.db.marshal.FloatType; |
| 66 | +import org.apache.cassandra.db.marshal.SimpleDateType; |
| 67 | +import org.apache.cassandra.db.marshal.TimeType; |
61 | 68 | import org.apache.cassandra.db.marshal.UTF8Type; |
62 | 69 | import org.apache.cassandra.dht.ByteOrderedPartitioner; |
63 | 70 | import org.apache.cassandra.dht.Murmur3Partitioner; |
|
76 | 83 | import org.apache.cassandra.utils.JavaDriverUtils; |
77 | 84 |
|
78 | 85 | import static org.apache.cassandra.utils.Clock.Global.currentTimeMillis; |
| 86 | +import static org.assertj.core.api.Assertions.assertThat; |
79 | 87 | import static org.junit.Assert.assertEquals; |
80 | 88 | import static org.junit.Assert.assertFalse; |
81 | 89 | import static org.junit.Assert.assertNotNull; |
@@ -1520,6 +1528,75 @@ public void testSkipBuildingIndexesWithSAI() throws Exception |
1520 | 1528 | assertFalse(indexDescriptor.isPerColumnIndexBuildComplete(new IndexIdentifier(keyspace, table, "idx2"))); |
1521 | 1529 | } |
1522 | 1530 |
|
| 1531 | + @Test |
| 1532 | + public void testWritingVectorData() throws Exception |
| 1533 | + { |
| 1534 | + testWritingVectorData(CQL3Type.Native.FLOAT, FloatType.instance, (i) -> (float) i, (i, vector) -> { |
| 1535 | + assertThat(vector).allMatch(val -> val instanceof Float); |
| 1536 | + assertThat(vector).allMatch(val -> (float) val == (float) i); |
| 1537 | + }); |
| 1538 | + |
| 1539 | + perTestSetup(); |
| 1540 | + |
| 1541 | + testWritingVectorData(CQL3Type.Native.DATE, SimpleDateType.instance, LocalDate::fromDaysSinceEpoch, (i, vector) -> { |
| 1542 | + assertThat(vector).allMatch(val -> val instanceof Integer); |
| 1543 | + assertThat(vector).allMatch(val -> { |
| 1544 | + int days = (int) val - Integer.MIN_VALUE; // signed to unsigned conversion |
| 1545 | + return days == i; |
| 1546 | + }); |
| 1547 | + }); |
| 1548 | + |
| 1549 | + perTestSetup(); |
| 1550 | + |
| 1551 | + testWritingVectorData(CQL3Type.Native.TIME, TimeType.instance, (i) -> (long) i, (i, vector) -> { |
| 1552 | + assertThat(vector).allMatch(val -> val instanceof Long); |
| 1553 | + assertThat(vector).allMatch(val -> (long) val == (long) i); |
| 1554 | + }); |
| 1555 | + } |
| 1556 | + |
| 1557 | + private void testWritingVectorData(CQL3Type.Native cqlType, AbstractType<?> subType, Function<Integer, ?> valueFactory, |
| 1558 | + BiConsumer<Integer, List<?>> checkFunction) throws Exception |
| 1559 | + { |
| 1560 | + final int dimensions = 5; |
| 1561 | + final String schema = "CREATE TABLE " + qualifiedTable + " (" |
| 1562 | + + " k int," |
| 1563 | + + " v1 VECTOR<" + cqlType.name() + ", " + dimensions + ">," |
| 1564 | + + " PRIMARY KEY (k)" |
| 1565 | + + ")"; |
| 1566 | + |
| 1567 | + CQLSSTableWriter writer = CQLSSTableWriter.builder() |
| 1568 | + .inDirectory(dataDir) |
| 1569 | + .forTable(schema) |
| 1570 | + .using("INSERT INTO " + keyspace + "." + table + " (k, v1) " + |
| 1571 | + "VALUES (?, ?)").build(); |
| 1572 | + |
| 1573 | + for (int i = 0; i < 100; i++) |
| 1574 | + { |
| 1575 | + List<Object> vector = new ArrayList<>(dimensions); |
| 1576 | + for (int j = 0; j < dimensions; j++) |
| 1577 | + { |
| 1578 | + vector.add(valueFactory.apply(i)); |
| 1579 | + } |
| 1580 | + writer.addRow(i, vector); |
| 1581 | + } |
| 1582 | + |
| 1583 | + writer.close(); |
| 1584 | + loadSSTables(dataDir, keyspace); |
| 1585 | + |
| 1586 | + UntypedResultSet resultSet = QueryProcessor.executeInternal("SELECT * FROM " + keyspace + "." + table); |
| 1587 | + |
| 1588 | + assertEquals(resultSet.size(), 100); |
| 1589 | + int cnt = 0; |
| 1590 | + for (UntypedResultSet.Row row : resultSet) |
| 1591 | + { |
| 1592 | + assertEquals(cnt, row.getInt("k")); |
| 1593 | + List<?> vector = row.getVector("v1", subType, dimensions); |
| 1594 | + assertThat(vector).hasSize(dimensions); |
| 1595 | + checkFunction.accept(cnt, vector); |
| 1596 | + cnt++; |
| 1597 | + } |
| 1598 | + } |
| 1599 | + |
1523 | 1600 | protected void loadSSTables(File dataDir, String ksName) |
1524 | 1601 | { |
1525 | 1602 | ColumnFamilyStore cfs = Keyspace.openWithoutSSTables(ksName).getColumnFamilyStore(table); |
|
0 commit comments