Skip to content

Commit ba80bab

Browse files
committed
Added Disco example [skip ci]
1 parent 896ba22 commit ba80bab

File tree

3 files changed

+127
-1
lines changed

3 files changed

+127
-1
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ And follow the instructions for your database library:
2222

2323
- [libpqxx](#libpqxx)
2424

25-
Or check out an example:
25+
Or check out some examples:
2626

2727
- [Embeddings](examples/openai/example.cpp) with OpenAI
28+
- [Recommendations](examples/disco/example.cpp) with Disco
2829

2930
## libpqxx
3031

examples/disco/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Build script for the Disco + pgvector example (example.cpp).
# Fetches disco-cpp and libpqxx at configure time and links the example
# against libpqxx; the pgvector headers are taken from this repository.
cmake_minimum_required(VERSION 3.18)

project(example)

# example.cpp uses designated initializers, so C++20 is mandatory; without
# *_REQUIRED CMake may silently fall back to an older standard.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Append instead of overwriting so flags supplied by the user/environment
# (e.g. CXXFLAGS or -DCMAKE_CXX_FLAGS=...) are preserved.
string(APPEND CMAKE_CXX_FLAGS " -Wno-unknown-attributes")

include(FetchContent)

FetchContent_Declare(disco GIT_REPOSITORY https://github.com/ankane/disco-cpp.git GIT_TAG v0.1.1)
FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.0)
FetchContent_MakeAvailable(disco libpqxx)

add_executable(example example.cpp)

# CMAKE_CURRENT_SOURCE_DIR keeps the relative path to the repository's
# include/ directory correct even if this directory is added from a parent
# project (CMAKE_SOURCE_DIR would then point at the parent).
target_include_directories(example PRIVATE ${disco_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
target_link_libraries(example PRIVATE pqxx)

examples/disco/example.cpp

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#include <cassert>
2+
#include <cstdlib>
3+
#include <fstream>
4+
#include <iostream>
5+
#include <sstream>
6+
#include <string>
7+
#include <unordered_map>
8+
9+
#include <disco.hpp>
10+
// TODO make <pgvector/pqxx.hpp>
11+
#include <pqxx.hpp>
12+
#include <pqxx/pqxx>
13+
14+
using disco::Dataset;
15+
using disco::Recommender;
16+
17+
// Converts an ISO-8859-1 (Latin-1) encoded string to UTF-8.
//
// Bytes below 0x80 are ASCII and are copied through unchanged. Every other
// byte stands for the code point with the same numeric value
// (U+0080..U+00FF), which UTF-8 encodes as two bytes: 110000xx 10xxxxxx.
std::string convert_to_utf8(const std::string& str) {
    std::string out;
    out.reserve(str.size());
    for (const char c : str) {
        const auto v = static_cast<unsigned char>(c);
        if (v < 0x80) {
            // first 128 are same
            out.push_back(c);
        } else {
            // Lead byte is 0xC2 for 0x80..0xBF and 0xC3 for 0xC0..0xFF;
            // the continuation byte carries the low six bits.
            out.push_back(static_cast<char>(0xC0 | (v >> 6)));
            out.push_back(static_cast<char>(0x80 | (v & 0x3F)));
        }
    }
    return out;
}
31+
32+
// Loads MovieLens ratings from the directory `path` into a Disco dataset.
//
// Reads `path + "/u.item"` ('|'-separated; first field is the movie id,
// second the title) to build an id -> title map, with titles passed through
// convert_to_utf8. Then reads `path + "/u.data"` (tab-separated; user id,
// movie id, rating, ...) and pushes (user id, movie title, rating) for each
// line.
//
// Throws std::ios_base::failure if either file cannot be opened; unlike an
// assert, this check is still present when NDEBUG is defined. Blank or
// truncated lines are skipped. std::stoi / std::stof and movies.at() may
// still throw on non-numeric fields or on a rating whose movie id is unknown.
Dataset<int, std::string> load_movielens(const std::string& path) {
    std::string line;

    // read movies
    std::unordered_map<std::string, std::string> movies;
    std::ifstream movies_file(path + "/u.item");
    if (!movies_file.is_open()) {
        throw std::ios_base::failure("Could not open " + path + "/u.item");
    }
    while (std::getline(movies_file, line)) {
        const std::string::size_type n = line.find('|');
        if (n == std::string::npos) {
            // no id/title separator (e.g. a blank trailing line)
            continue;
        }
        const std::string::size_type n2 = line.find('|', n + 1);
        // if there is no second separator the title runs to the end of the line
        const std::string::size_type title_len =
            (n2 == std::string::npos) ? std::string::npos : n2 - n - 1;
        movies.emplace(line.substr(0, n), convert_to_utf8(line.substr(n + 1, title_len)));
    }

    // read ratings and create dataset
    auto data = Dataset<int, std::string>();
    std::ifstream ratings_file(path + "/u.data");
    if (!ratings_file.is_open()) {
        throw std::ios_base::failure("Could not open " + path + "/u.data");
    }
    while (std::getline(ratings_file, line)) {
        const std::string::size_type n = line.find('\t');
        if (n == std::string::npos) {
            continue;
        }
        const std::string::size_type n2 = line.find('\t', n + 1);
        if (n2 == std::string::npos) {
            // user id and movie id present but no rating field
            continue;
        }
        const std::string::size_type n3 = line.find('\t', n2 + 1);
        // if nothing follows the rating, it runs to the end of the line
        const std::string::size_type rating_len =
            (n3 == std::string::npos) ? std::string::npos : n3 - n2 - 1;
        data.push(
            std::stoi(line.substr(0, n)),
            movies.at(line.substr(n + 1, n2 - n - 1)),
            std::stof(line.substr(n2 + 1, rating_len))
        );
    }

    return data;
}
62+
63+
int main() {
64+
// https://grouplens.org/datasets/movielens/100k/
65+
char *movielens_path = std::getenv("MOVIELENS_100K_PATH");
66+
if (!movielens_path) {
67+
std::cout << "Set MOVIELENS_100K_PATH" << std::endl;
68+
return 1;
69+
}
70+
71+
pqxx::connection conn("dbname=pgvector_example");
72+
73+
pqxx::work tx(conn);
74+
tx.exec("CREATE EXTENSION IF NOT EXISTS vector");
75+
tx.exec("DROP TABLE IF EXISTS users");
76+
tx.exec("DROP TABLE IF EXISTS movies");
77+
tx.exec("CREATE TABLE users (id integer PRIMARY KEY, factors vector(20))");
78+
tx.exec("CREATE TABLE movies (name text PRIMARY KEY, factors vector(20))");
79+
tx.commit();
80+
81+
auto data = load_movielens(movielens_path);
82+
auto recommender = Recommender<int, std::string>::fit_explicit(data, { .factors = 20 });
83+
84+
for (auto& user_id : recommender.user_ids()) {
85+
auto factors = pgvector::Vector(*recommender.user_factors(user_id));
86+
tx.exec("INSERT INTO users (id, factors) VALUES ($1, $2)", {user_id, factors});
87+
}
88+
89+
for (auto& item_id : recommender.item_ids()) {
90+
auto factors = pgvector::Vector(*recommender.item_factors(item_id));
91+
tx.exec("INSERT INTO movies (name, factors) VALUES ($1, $2)", {item_id, factors});
92+
}
93+
94+
std::string movie = "Star Wars (1977)";
95+
std::cout << "Item-based recommendations for " << movie << std::endl;
96+
pqxx::result result = tx.exec("SELECT name FROM movies WHERE name != $1 ORDER BY factors <=> (SELECT factors FROM movies WHERE name = $1) LIMIT 5", pqxx::params{movie});
97+
for (auto const& row : result) {
98+
std::cout << "- " << row[0].c_str() << std::endl;
99+
}
100+
101+
int user_id = 123;
102+
std::cout << std::endl << "User-based recommendations for " << user_id << std::endl;
103+
result = tx.exec("SELECT name FROM movies ORDER BY factors <#> (SELECT factors FROM users WHERE id = $1) LIMIT 5", {user_id});
104+
for (auto const& row : result) {
105+
std::cout << "- " << row[0].c_str() << std::endl;
106+
}
107+
108+
return 0;
109+
}

0 commit comments

Comments
 (0)