Skip to content

Commit df46511

Browse files
committed
Improvements to EF type mapping
* Support scaffolding of vector types
* Implemented more proper support for the size facet
* Consolidated different mappings and plugins into the same files
* A bit of test code cleanup

Closes #44
1 parent 4de753b commit df46511

File tree

8 files changed

+92
-88
lines changed

8 files changed

+92
-88
lines changed

src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/Pgvector.EntityFrameworkCore/SparsevecTypeMappingSourcePlugin.cs

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/Pgvector.EntityFrameworkCore/VectorDbContextOptionsExtension.cs

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,6 @@ public void ApplyServices(IServiceCollection services)
1717
.TryAdd<IMethodCallTranslatorPlugin, VectorDbFunctionsTranslatorPlugin>();
1818

1919
services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
20-
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
21-
services.AddSingleton<IRelationalTypeMappingSourcePlugin, SparsevecTypeMappingSourcePlugin>();
2220
}
2321

2422
public void Validate(IDbContextOptions options) { }

src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,18 @@ namespace Pgvector.EntityFrameworkCore;
66

77
public class VectorTypeMapping : RelationalTypeMapping
88
{
9-
public static VectorTypeMapping Default { get; } = new();
9+
public static VectorTypeMapping Default { get; } = new("vector", typeof(Vector));
1010

11-
public VectorTypeMapping() : base("vector", typeof(Vector)) { }
12-
13-
public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
11+
public VectorTypeMapping(string storeType, Type clrType, int? size = null)
12+
: this(
13+
new RelationalTypeMappingParameters(
14+
new CoreTypeMappingParameters(clrType),
15+
storeType,
16+
StoreTypePostfix.Size,
17+
size: size,
18+
fixedLength: true))
19+
{
20+
}
1421

1522
protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }
1623

src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs

Lines changed: 26 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,30 @@ namespace Pgvector.EntityFrameworkCore;
55
public class VectorTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
66
{
77
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
8-
=> mappingInfo.ClrType == typeof(Vector)
9-
? new VectorTypeMapping(mappingInfo.StoreTypeName ?? "vector")
10-
: null;
8+
{
9+
if (mappingInfo.StoreTypeName is not null)
10+
{
11+
VectorTypeMapping? mapping = (mappingInfo.StoreTypeNameBase ?? mappingInfo.StoreTypeName) switch
12+
{
13+
"vector" => new(mappingInfo.StoreTypeName, typeof(Vector), mappingInfo.Size),
14+
"halfvec" => new(mappingInfo.StoreTypeName, typeof(HalfVector), mappingInfo.Size),
15+
"sparsevec" => new(mappingInfo.StoreTypeName, typeof(SparseVector), mappingInfo.Size),
16+
_ => null,
17+
};
18+
19+
// If the caller hasn't specified a CLR type (this is scaffolding), or if the user has specified
20+
// the one matching the store type, return the mapping.
21+
return mappingInfo.ClrType is null || mappingInfo.ClrType == mapping?.ClrType
22+
? mapping : null;
23+
}
24+
25+
// No store type specified, look up by the CLR type only
26+
return mappingInfo.ClrType switch
27+
{
28+
var t when t == typeof(Vector) => new VectorTypeMapping("vector", typeof(Vector), mappingInfo.Size),
29+
var t when t == typeof(HalfVector) => new VectorTypeMapping("halfvec", typeof(HalfVector), mappingInfo.Size),
30+
var t when t == typeof(SparseVector) => new VectorTypeMapping("sparsevec", typeof(SparseVector), mappingInfo.Size),
31+
_ => null,
32+
};
33+
}
1134
}

tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs

Lines changed: 55 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -66,79 +66,115 @@ public async Task Main()
6666

6767
var embedding = new Vector(new float[] { 1, 1, 1 });
6868
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
69-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
70-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
71-
Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());
69+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
70+
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
71+
Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray());
7272
Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
73-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray());
73+
Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray());
7474

7575
// vector distance functions
7676

7777
items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
78-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
79-
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
78+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
79+
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
8080

8181
items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
82-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
82+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
8383

8484
items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
8585
Assert.Equal(3, items[2].Id);
8686

8787
items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
88-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
88+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
8989

9090
// halfvec distance functions
9191

9292
var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
9393
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
94-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
94+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
9595

9696
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
97-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
97+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
9898

9999
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
100100
Assert.Equal(3, items[2].Id);
101101

102102
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
103-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
103+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
104104

105105
// sparsevec distance functions
106106

107107
var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
108108
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
109-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
109+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
110110

111111
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
112-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
112+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
113113

114114
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
115115
Assert.Equal(3, items[2].Id);
116116

117117
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
118-
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
118+
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
119119

120120
// bit distance functions
121121

122122
var binaryEmbedding = new BitArray(new bool[] { true, false, true });
123123
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
124-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
124+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
125125

126126
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
127-
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
127+
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());
128128

129129
// additional
130130

131131
items = await ctx.Items
132132
.OrderBy(x => x.Id)
133133
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
134134
.ToListAsync();
135-
Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray());
135+
Assert.Equal([1, 3], items.Select(v => v.Id).ToArray());
136136

137137
var neighbors = await ctx.Items
138138
.OrderBy(x => x.Embedding!.L2Distance(embedding))
139139
.Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
140140
.ToListAsync();
141-
Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray());
142-
Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray());
141+
Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray());
142+
Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => v.Distance).ToArray());
143+
}
144+
145+
[Theory]
146+
[InlineData(typeof(Vector), null, "vector")]
147+
[InlineData(typeof(Vector), 3, "vector(3)")]
148+
[InlineData(typeof(HalfVector), null, "halfvec")]
149+
[InlineData(typeof(HalfVector), 3, "halfvec(3)")]
150+
[InlineData(typeof(SparseVector), null, "sparsevec")]
151+
[InlineData(typeof(SparseVector), 3, "sparsevec(3)")]
152+
public void By_StoreType(Type type, int? size, string expectedStoreType)
153+
{
154+
using var ctx = new ItemContext();
155+
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();
156+
157+
var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!;
158+
Assert.Equal(expectedStoreType, typeMapping.StoreType);
159+
Assert.Same(type, typeMapping.ClrType);
160+
Assert.Equal(size, typeMapping.Size);
161+
}
162+
163+
[Theory]
164+
[InlineData("vector", typeof(Vector), null)]
165+
[InlineData("vector(3)", typeof(Vector), 3)]
166+
[InlineData("halfvec", typeof(HalfVector), null)]
167+
[InlineData("halfvec(3)", typeof(HalfVector), 3)]
168+
[InlineData("sparsevec", typeof(SparseVector), null)]
169+
[InlineData("sparsevec(3)", typeof(SparseVector), 3)]
170+
public void By_ClrType(string storeType, Type expectedType, int? expectedSize)
171+
{
172+
using var ctx = new ItemContext();
173+
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();
174+
175+
var typeMapping = typeMappingSource.FindMapping(storeType)!;
176+
Assert.Equal(storeType, typeMapping.StoreType);
177+
Assert.Same(expectedType, typeMapping.ClrType);
178+
Assert.Equal(expectedSize, typeMapping.Size);
143179
}
144180
}

0 commit comments

Comments
 (0)