Skip to content

Commit 73acf9e

Browse files
committed
Added encoding and decoding to PgvectorNIO
1 parent a0b7b4e commit 73acf9e

File tree

6 files changed

+241
-8
lines changed

6 files changed

+241
-8
lines changed

Sources/Pgvector/SparseVector.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ public struct SparseVector: Equatable {
1818
self.values = values
1919
}
2020

21+
/// Creates a sparse vector from its dimension count, the positions of its
/// stored entries, and the values at those positions.
///
/// - Parameters:
///   - dim: Total number of dimensions of the vector.
///   - indices: Positions of the stored entries.
///   - values: Stored values; must have exactly as many elements as `indices`.
/// - Returns: `nil` when `indices` and `values` differ in length.
// TODO check indices sorted and non-negative
public init?(dim: Int, indices: [Int], values: [Float]) {
    // Each stored value needs exactly one index, so a length mismatch is rejected.
    if indices.count != values.count {
        return nil
    }

    self.values = values
    self.indices = indices
    self.dim = dim
}
31+
2132
public init?(_ string: String) {
2233
let parts = string.split(separator: "/", maxSplits: 2)
2334
guard parts.count == 2 else {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import Pgvector
2+
import PostgresNIO
3+
4+
/// Lets `HalfVector` be bound as a PostgresNIO query parameter in binary form.
extension HalfVector: @retroactive PostgresEncodable {
    /// Type OID reported for encoded values. Starts as the placeholder `1`;
    /// `PgvectorNIO.registerTypes` overwrites it with the OID looked up in the database.
    public static var psqlType: PostgresDataType = PostgresDataType(1)

    /// Values are always sent in binary format.
    public static var psqlFormat: PostgresFormat {
        return PostgresFormat.binary
    }

    /// Writes the vector into `byteBuffer`: a 16-bit element count, a 16-bit
    /// zero field, then the 16-bit bit pattern of every element.
    ///
    /// - Parameters:
    ///   - byteBuffer: Destination buffer the encoded bytes are appended to.
    ///   - context: Encoding context supplied by PostgresNIO; not used here.
    public func encode<JSONEncoder: PostgresJSONEncoder>(
        into byteBuffer: inout ByteBuffer,
        context: PostgresEncodingContext<JSONEncoder>
    ) {
        // Header: element count followed by a field that is always zero.
        let elementCount = Int16(value.count)
        byteBuffer.writeInteger(elementCount, as: Int16.self)
        byteBuffer.writeInteger(0, as: Int16.self)

        // Payload: raw half-precision bit patterns, in order.
        for element in value {
            let bits = element.bitPattern
            byteBuffer.writeInteger(bits, as: UInt16.self)
        }
    }
}
22+
23+
/// Lets `HalfVector` be decoded from a PostgresNIO result cell in binary form.
extension HalfVector: @retroactive PostgresDecodable {
    /// Decodes a vector from `buffer`: a 16-bit element count, a 16-bit field
    /// that must be zero, then one 16-bit bit pattern per element.
    ///
    /// - Parameters:
    ///   - buffer: Buffer positioned at the start of the encoded value.
    ///   - type: Postgres type of the cell; must be a user-defined type.
    ///   - format: Wire format of the cell; only `.binary` is accepted.
    ///   - context: Decoding context supplied by PostgresNIO; not used here.
    /// - Throws: `PostgresDecodingError.Code.typeMismatch` when `type` is not
    ///   user-defined, and `PostgresDecodingError.Code.failure` for a
    ///   non-binary format or a malformed or truncated payload.
    public init<JSONDecoder: PostgresJSONDecoder>(
        from buffer: inout ByteBuffer,
        type: PostgresDataType,
        format: PostgresFormat,
        context: PostgresDecodingContext<JSONDecoder>
    ) throws {
        guard type.isUserDefined else {
            throw PostgresDecodingError.Code.typeMismatch
        }

        guard format == .binary else {
            throw PostgresDecodingError.Code.failure
        }

        // The count is a signed 16-bit field. A negative value would make the
        // `0..<count` range below trap, so it is rejected as malformed input.
        guard let dim = buffer.readInteger(as: Int16.self), dim >= 0 else {
            throw PostgresDecodingError.Code.failure
        }

        guard let unused = buffer.readInteger(as: Int16.self), unused == 0 else {
            throw PostgresDecodingError.Code.failure
        }

        // Fail fast when the payload is shorter than the header promises
        // (two bytes per element).
        let count = Int(dim)
        guard buffer.readableBytes >= count * 2 else {
            throw PostgresDecodingError.Code.failure
        }

        var value: [Float16] = []
        value.reserveCapacity(count)
        for _ in 0..<count {
            guard let bits = buffer.readInteger(as: UInt16.self) else {
                throw PostgresDecodingError.Code.failure
            }
            value.append(Float16(bitPattern: bits))
        }
        self.init(value)
    }
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import Pgvector
2+
import PostgresNIO
3+
4+
/// Errors thrown by the PgvectorNIO helpers.
public enum PgvectorError: Error {
    /// A failure described by a human-readable message.
    case string(String)
}
7+
8+
/// Namespace for PostgresNIO integration helpers.
public struct PgvectorNIO {
    // note: global OID for each type is not ideal since it will be different for each database
    /// Looks up the OIDs of the `vector`, `halfvec` and `sparsevec` types in the
    /// connected database and stores them in the matching `psqlType` statics.
    ///
    /// `halfvec` and `sparsevec` are optional: when the database does not
    /// define them, their `psqlType` is left untouched.
    ///
    /// - Parameter client: Client used to run the lookup query.
    /// - Throws: `PgvectorError.string` when the `vector` type is not present,
    ///   plus any error thrown by the query or by row decoding.
    public static func registerTypes(_ client: PostgresClient) async throws {
        // `to_regtype` yields NULL for an unknown type name, whereas a plain
        // `regtype(...)` cast raises an error and would abort the whole query
        // when any one of the three types is missing.
        let rows = try await client.query("SELECT to_regtype('vector')::oid::integer, to_regtype('halfvec')::oid::integer, to_regtype('sparsevec')::oid::integer")

        var iterator = rows.makeAsyncIterator()
        guard let row = try await iterator.next() else {
            throw PgvectorError.string("unreachable")
        }
        let (vectorOid, halfvecOid, sparsevecOid) = try row.decode((Int?, Int?, Int?).self)

        guard let oid = vectorOid else {
            throw PgvectorError.string("vector type not found in the database")
        }
        Vector.psqlType = makeDataType(fromOid: oid)

        if let oid = halfvecOid {
            HalfVector.psqlType = makeDataType(fromOid: oid)
        }

        if let oid = sparsevecOid {
            SparseVector.psqlType = makeDataType(fromOid: oid)
        }
    }

    /// Converts an OID that was selected as a signed SQL `integer` into a `PostgresDataType`.
    private static func makeDataType(fromOid oid: Int) -> PostgresDataType {
        // OIDs are unsigned 32-bit values, so one above `Int32.max` arrives as
        // a negative number after the `::integer` cast. Truncation restores
        // the original bit pattern instead of trapping the way `UInt32(oid)` would.
        return PostgresDataType(UInt32(truncatingIfNeeded: oid))
    }
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import Pgvector
2+
import PostgresNIO
3+
4+
/// Lets `SparseVector` be bound as a PostgresNIO query parameter in binary form.
extension SparseVector: @retroactive PostgresEncodable {
    /// Type OID reported for encoded values. Starts as the placeholder `1`;
    /// `PgvectorNIO.registerTypes` overwrites it with the OID looked up in the database.
    public static var psqlType: PostgresDataType = PostgresDataType(1)

    /// Values are always sent in binary format.
    public static var psqlFormat: PostgresFormat {
        return PostgresFormat.binary
    }

    /// Writes the vector into `byteBuffer`: three 32-bit header fields
    /// (dimension count, number of stored entries, zero), then every index as
    /// a 32-bit integer, then every value as its 32-bit bit pattern.
    ///
    /// - Parameters:
    ///   - byteBuffer: Destination buffer the encoded bytes are appended to.
    ///   - context: Encoding context supplied by PostgresNIO; not used here.
    public func encode<JSONEncoder: PostgresJSONEncoder>(
        into byteBuffer: inout ByteBuffer,
        context: PostgresEncodingContext<JSONEncoder>
    ) {
        // Header.
        let dimensionCount = Int32(dim)
        let storedCount = Int32(indices.count)
        byteBuffer.writeInteger(dimensionCount, as: Int32.self)
        byteBuffer.writeInteger(storedCount, as: Int32.self)
        byteBuffer.writeInteger(0, as: Int32.self)

        // All indices come first, followed by all values.
        for index in indices {
            byteBuffer.writeInteger(Int32(index), as: Int32.self)
        }
        for element in values {
            let bits = element.bitPattern
            byteBuffer.writeInteger(bits, as: UInt32.self)
        }
    }
}
26+
27+
/// Lets `SparseVector` be decoded from a PostgresNIO result cell in binary form.
extension SparseVector: @retroactive PostgresDecodable {
    /// Decodes a sparse vector from `buffer`: three 32-bit header fields
    /// (dimension count, number of stored entries, a field that must be
    /// zero), then one 32-bit index per entry, then one 32-bit float bit
    /// pattern per entry.
    ///
    /// - Parameters:
    ///   - buffer: Buffer positioned at the start of the encoded value.
    ///   - type: Postgres type of the cell; must be a user-defined type.
    ///   - format: Wire format of the cell; only `.binary` is accepted.
    ///   - context: Decoding context supplied by PostgresNIO; not used here.
    /// - Throws: `PostgresDecodingError.Code.typeMismatch` when `type` is not
    ///   user-defined, and `PostgresDecodingError.Code.failure` for a
    ///   non-binary format or a malformed or truncated payload.
    public init<JSONDecoder: PostgresJSONDecoder>(
        from buffer: inout ByteBuffer,
        type: PostgresDataType,
        format: PostgresFormat,
        context: PostgresDecodingContext<JSONDecoder>
    ) throws {
        guard type.isUserDefined else {
            throw PostgresDecodingError.Code.typeMismatch
        }

        guard format == .binary else {
            throw PostgresDecodingError.Code.failure
        }

        guard let dim = buffer.readInteger(as: Int32.self) else {
            throw PostgresDecodingError.Code.failure
        }

        // The entry count is a signed 32-bit field. A negative value would
        // make the `0..<count` ranges below trap, so it is rejected as
        // malformed input.
        guard let nnz = buffer.readInteger(as: Int32.self), nnz >= 0 else {
            throw PostgresDecodingError.Code.failure
        }

        guard let unused = buffer.readInteger(as: Int32.self), unused == 0 else {
            throw PostgresDecodingError.Code.failure
        }

        // Each entry needs 8 bytes (4 for the index, 4 for the value). Checking
        // up front fails fast on truncated input and keeps `reserveCapacity`
        // from being driven by an absurd count in the header. Dividing rather
        // than multiplying avoids any overflow.
        let count = Int(nnz)
        guard buffer.readableBytes / 8 >= count else {
            throw PostgresDecodingError.Code.failure
        }

        var indices: [Int] = []
        indices.reserveCapacity(count)
        for _ in 0..<count {
            guard let index = buffer.readInteger(as: Int32.self) else {
                throw PostgresDecodingError.Code.failure
            }
            indices.append(Int(index))
        }

        var values: [Float] = []
        values.reserveCapacity(count)
        for _ in 0..<count {
            guard let bits = buffer.readInteger(as: UInt32.self) else {
                throw PostgresDecodingError.Code.failure
            }
            values.append(Float(bitPattern: bits))
        }

        // The memberwise initializer is failable; report a rejection as a
        // decoding error rather than force-unwrapping.
        guard let decoded = SparseVector(dim: Int(dim), indices: indices, values: values) else {
            throw PostgresDecodingError.Code.failure
        }
        self = decoded
    }
}

Sources/PgvectorNIO/Vector.swift

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,55 @@
11
import Pgvector
22
import PostgresNIO
3+
4+
/// Lets `Vector` be bound as a PostgresNIO query parameter in binary form.
extension Vector: @retroactive PostgresEncodable {
    /// Type OID reported for encoded values. Starts as the placeholder `1`;
    /// `PgvectorNIO.registerTypes` overwrites it with the OID looked up in the database.
    public static var psqlType: PostgresDataType = PostgresDataType(1)

    /// Values are always sent in binary format.
    public static var psqlFormat: PostgresFormat {
        return PostgresFormat.binary
    }

    /// Writes the vector into `byteBuffer`: a 16-bit element count, a 16-bit
    /// zero field, then the bit pattern of every element.
    ///
    /// - Parameters:
    ///   - byteBuffer: Destination buffer the encoded bytes are appended to.
    ///   - context: Encoding context supplied by PostgresNIO; not used here.
    public func encode<JSONEncoder: PostgresJSONEncoder>(
        into byteBuffer: inout ByteBuffer,
        context: PostgresEncodingContext<JSONEncoder>
    ) {
        // Header: element count followed by a field that is always zero.
        let elementCount = Int16(value.count)
        byteBuffer.writeInteger(elementCount, as: Int16.self)
        byteBuffer.writeInteger(0, as: Int16.self)

        // Payload: raw float bit patterns, in order.
        for element in value {
            let bits = element.bitPattern
            byteBuffer.writeInteger(bits)
        }
    }
}
22+
23+
/// Lets `Vector` be decoded from a PostgresNIO result cell in binary form.
extension Vector: @retroactive PostgresDecodable {
    /// Decodes a vector from `buffer`: a 16-bit element count, a 16-bit field
    /// that must be zero, then one 32-bit float bit pattern per element.
    ///
    /// - Parameters:
    ///   - buffer: Buffer positioned at the start of the encoded value.
    ///   - type: Postgres type of the cell; must be a user-defined type.
    ///   - format: Wire format of the cell; only `.binary` is accepted.
    ///   - context: Decoding context supplied by PostgresNIO; not used here.
    /// - Throws: `PostgresDecodingError.Code.typeMismatch` when `type` is not
    ///   user-defined, and `PostgresDecodingError.Code.failure` for a
    ///   non-binary format or a malformed or truncated payload.
    public init<JSONDecoder: PostgresJSONDecoder>(
        from buffer: inout ByteBuffer,
        type: PostgresDataType,
        format: PostgresFormat,
        context: PostgresDecodingContext<JSONDecoder>
    ) throws {
        guard type.isUserDefined else {
            throw PostgresDecodingError.Code.typeMismatch
        }

        guard format == .binary else {
            throw PostgresDecodingError.Code.failure
        }

        // The count is a signed 16-bit field. A negative value would make the
        // `0..<count` range below trap, so it is rejected as malformed input.
        guard let dim = buffer.readInteger(as: Int16.self), dim >= 0 else {
            throw PostgresDecodingError.Code.failure
        }

        guard let unused = buffer.readInteger(as: Int16.self), unused == 0 else {
            throw PostgresDecodingError.Code.failure
        }

        // Fail fast when the payload is shorter than the header promises
        // (four bytes per element).
        let count = Int(dim)
        guard buffer.readableBytes >= count * 4 else {
            throw PostgresDecodingError.Code.failure
        }

        var value: [Float] = []
        value.reserveCapacity(count)
        for _ in 0..<count {
            guard let bits = buffer.readInteger(as: UInt32.self) else {
                throw PostgresDecodingError.Code.failure
            }
            value.append(Float(bitPattern: bits))
        }
        self.init(value)
    }
}

Tests/PgvectorTests/PostgresNIOTests.swift

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,31 @@ final class PostgresNIOTests {
2323
}
2424

2525
try await client.query("CREATE EXTENSION IF NOT EXISTS vector")
26+
try await PgvectorNIO.registerTypes(client)
27+
2628
try await client.query("DROP TABLE IF EXISTS nio_items")
2729
try await client.query("CREATE TABLE nio_items (id bigserial PRIMARY KEY, embedding vector(3))")
2830

29-
let embedding1 = Vector([1, 1, 1]).text()
30-
let embedding2 = Vector([2, 2, 2]).text()
31-
let embedding3 = Vector([1, 1, 2]).text()
32-
try await client.query("INSERT INTO nio_items (embedding) VALUES (\(embedding1)::vector), (\(embedding2)::vector), (\(embedding3)::vector)")
31+
let embedding1 = Vector([1, 1, 1])
32+
let embedding2 = Vector([2, 2, 2])
33+
let embedding3 = Vector([1, 1, 2])
34+
try await client.query("INSERT INTO nio_items (embedding) VALUES (\(embedding1)), (\(embedding2)), (\(embedding3))")
3335

34-
let embedding = Vector([1, 1, 1]).text()
35-
let rows = try await client.query("SELECT id, embedding::text FROM nio_items ORDER BY embedding <-> \(embedding)::vector LIMIT 5")
36-
for try await (id, embedding) in rows.decode((Int, String).self) {
37-
print(id, Vector(embedding)!)
36+
let embedding = Vector([1, 1, 1])
37+
let rows = try await client.query("SELECT id, embedding FROM nio_items ORDER BY embedding <-> \(embedding) LIMIT 5")
38+
for try await (id, embedding) in rows.decode((Int, Vector).self) {
39+
print(id, embedding)
3840
}
3941

4042
try await client.query("CREATE INDEX ON nio_items USING hnsw (embedding vector_l2_ops)")
4143

44+
let halfEmbedding = HalfVector([1, 2, 3])
45+
let sparseEmbedding = SparseVector([1, 0, 2, 0, 3, 0])
46+
let typeRows = try await client.query("SELECT \(embedding), \(halfEmbedding), \(sparseEmbedding)")
47+
for try await (v, h, s) in typeRows.decode((Vector, HalfVector, SparseVector).self) {
48+
print(v, h, s)
49+
}
50+
4251
taskGroup.cancelAll()
4352
}
4453
}

0 commit comments

Comments
 (0)