Skip to content

Commit 0fbe919

Browse files
committed
Improved examples [skip ci]
1 parent bf32cf6 commit 0fbe919

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

examples/cohere/example.js

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ await pgvector.registerTypes(client);
1111
await client.query('DROP TABLE IF EXISTS documents');
1212
await client.query('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))');
1313

14-
async function fetchEmbeddings(texts, inputType) {
14+
async function embed(texts, inputType) {
1515
const cohere = new CohereClient();
1616
const response = await cohere.embed({
1717
texts: texts,
@@ -29,14 +29,14 @@ const input = [
2929
'The cat is purring',
3030
'The bear is growling'
3131
];
32-
const embeddings = await fetchEmbeddings(input, 'search_document');
32+
const embeddings = await embed(input, 'search_document');
3333
for (let [i, content] of input.entries()) {
3434
await client.query('INSERT INTO documents (content, embedding) VALUES ($1, $2)', [content, embeddings[i]]);
3535
}
3636

3737
const query = 'forest';
38-
const queryEmbedding = (await fetchEmbeddings([query], 'search_query'))[0];
39-
const { rows } = await client.query('SELECT * FROM documents ORDER BY embedding <~> $1 LIMIT 5', [queryEmbedding]);
38+
const queryEmbedding = (await embed([query], 'search_query'))[0];
39+
const { rows } = await client.query('SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5', [queryEmbedding]);
4040
for (let row of rows) {
4141
console.log(row.content);
4242
}

examples/hybrid-search/example.js

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
import { pipeline } from '@xenova/transformers';
1+
import { pipeline } from '@huggingface/transformers';
22
import pg from 'pg';
33
import pgvector from 'pgvector/pg';
44

@@ -18,15 +18,15 @@ const input = [
1818
'The bear is growling'
1919
];
2020

21-
const extractor = await pipeline('feature-extraction', 'Xenova/multi-qa-MiniLM-L6-cos-v1');
21+
const extractor = await pipeline('feature-extraction', 'Xenova/multi-qa-MiniLM-L6-cos-v1', {dtype: 'fp32'});
2222

23-
async function generateEmbedding(content) {
23+
async function embed(content) {
2424
const output = await extractor(content, {pooling: 'mean', normalize: true});
2525
return Array.from(output.data);
2626
}
2727

2828
for (let content of input) {
29-
const embedding = await generateEmbedding(content);
29+
const embedding = await embed(content);
3030
await client.query('INSERT INTO documents (content, embedding) VALUES ($1, $2)', [content, pgvector.toSql(embedding)]);
3131
}
3232

@@ -54,7 +54,7 @@ ORDER BY score DESC
5454
LIMIT 5
5555
`;
5656
const query = 'growling bear'
57-
const embedding = await generateEmbedding(query);
57+
const embedding = await embed(query);
5858
const k = 60
5959
const { rows } = await client.query(sql, [query, pgvector.toSql(embedding), k]);
6060
for (let row of rows) {

examples/hybrid-search/package.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
"private": true,
33
"type": "module",
44
"dependencies": {
5-
"@xenova/transformers": "^2.6.0",
5+
"@huggingface/transformers": "^3.3.3",
66
"pg": "^8.11.3",
77
"pgvector": "file:../.."
88
}

examples/openai/example.js

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -11,22 +11,25 @@ await pgvector.registerTypes(client);
1111
await client.query('DROP TABLE IF EXISTS documents');
1212
await client.query('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(1536))');
1313

14+
async function embed(input) {
15+
const openai = new OpenAI();
16+
const response = await openai.embeddings.create({input: input, model: 'text-embedding-3-small'});
17+
return response.data.map((v) => v.embedding);
18+
}
19+
1420
const input = [
1521
'The dog is barking',
1622
'The cat is purring',
1723
'The bear is growling'
1824
];
19-
20-
const openai = new OpenAI();
21-
const response = await openai.embeddings.create({input: input, model: 'text-embedding-3-small'});
22-
const embeddings = response.data.map((v) => v.embedding);
23-
25+
const embeddings = await embed(input);
2426
for (let [i, content] of input.entries()) {
2527
await client.query('INSERT INTO documents (content, embedding) VALUES ($1, $2)', [content, pgvector.toSql(embeddings[i])]);
2628
}
2729

28-
const documentId = 2;
29-
const { rows } = await client.query('SELECT * FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5', [documentId]);
30+
const query = 'forest';
31+
const queryEmbedding = (await embed([query]))[0];
32+
const { rows } = await client.query('SELECT content FROM documents ORDER BY embedding <=> $1 LIMIT 5', [pgvector.toSql(queryEmbedding)]);
3033
for (let row of rows) {
3134
console.log(row.content);
3235
}

examples/sparse-search/example.js

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ await pgvector.registerTypes(client);
1818
await client.query('DROP TABLE IF EXISTS documents');
1919
await client.query('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))');
2020

21-
async function fetchEmbeddings(inputs) {
21+
async function embed(inputs) {
2222
const url = 'http://localhost:3000/embed_sparse';
2323
const data = {inputs: inputs};
2424
const options = {
@@ -48,14 +48,14 @@ const input = [
4848
'The bear is growling'
4949
];
5050

51-
const embeddings = await fetchEmbeddings(input);
51+
const embeddings = await embed(input);
5252
for (let [i, content] of input.entries()) {
5353
await client.query('INSERT INTO documents (content, embedding) VALUES ($1, $2)', [content, new SparseVector(embeddings[i], 30522)]);
5454
}
5555

5656
const query = 'forest';
57-
const queryEmbeddings = await fetchEmbeddings([query]);
58-
const { rows } = await client.query('SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5', [new SparseVector(queryEmbeddings[0], 30522)]);
57+
const queryEmbedding = (await embed([query]))[0];
58+
const { rows } = await client.query('SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5', [new SparseVector(queryEmbedding, 30522)]);
5959
for (let row of rows) {
6060
console.log(row.content);
6161
}

0 commit comments

Comments
 (0)