Skip to content

Commit 2a36320

Browse files
authored
feat(api): Semantic search on name and description (#47)
* feat(api): Semantic search on name and description * fix: clippy + bump docker rust version
1 parent 45a2136 commit 2a36320

File tree

18 files changed

+1390
-120
lines changed

18 files changed

+1390
-120
lines changed

Cargo.lock

Lines changed: 1047 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ docker run \
2323
### 2. Compile and run the indexer
2424
In a separate terminal, run the following commands:
2525
```bash
26-
cargo run --bin sink -- \
26+
CFLAGS='-std=gnu17' cargo run --bin sink -- \
2727
--reset-db \
2828
--neo4j-uri neo4j://localhost:7687 \
2929
--neo4j-user neo4j \
3030
--neo4j-pass neo4j
3131
```
3232

3333
```bash
34-
cargo run --bin api -- \
34+
CFLAGS='-std=gnu17' cargo run --bin api -- \
3535
--neo4j-uri neo4j://localhost:7687 \
3636
--neo4j-user neo4j \
3737
--neo4j-pass neo4j

api/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ grc20-core = { version = "0.1.0", path = "../grc20-core" }
2222
grc20-sdk = { version = "0.1.0", path = "../grc20-sdk" }
2323
cache = { version = "0.1.0", path = "../cache" }
2424
chrono = "0.4.39"
25+
fastembed = "4.8.0"
2526

2627
[dev-dependencies]
2728
serde_path_to_error = "0.1.16"

api/src/context.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
11
use cache::KgCache;
2+
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
23
use grc20_core::neo4rs;
34
use std::sync::Arc;
45

6+
const EMBEDDING_MODEL: EmbeddingModel = EmbeddingModel::AllMiniLML6V2;
7+
58
#[derive(Clone)]
69
pub struct KnowledgeGraph {
710
pub neo4j: Arc<neo4rs::Graph>,
811
pub cache: Option<Arc<KgCache>>,
12+
pub embedding_model: Arc<TextEmbedding>,
913
}
1014

1115
impl juniper::Context for KnowledgeGraph {}
1216

1317
impl KnowledgeGraph {
1418
pub fn new(neo4j: Arc<neo4rs::Graph>, cache: Option<Arc<KgCache>>) -> Self {
15-
Self { neo4j, cache }
19+
Self {
20+
neo4j,
21+
cache,
22+
embedding_model: Arc::new(
23+
TextEmbedding::try_new(
24+
InitOptions::new(EMBEDDING_MODEL).with_show_download_progress(true),
25+
)
26+
.expect("Failed to initialize embedding model"),
27+
),
28+
}
1629
}
1730
}

api/src/schema/query.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,4 +317,33 @@ impl RootQuery {
317317
.await?
318318
.map(|triple| Triple::new(triple, space_id, version_index)))
319319
}
320+
321+
async fn search<'a, S: ScalarValue>(
322+
&'a self,
323+
executor: &'a Executor<'_, '_, KnowledgeGraph, S>,
324+
query: String,
325+
#[graphql(default = 100)] first: i32,
326+
// #[graphql(default = 0)] skip: i32,
327+
) -> FieldResult<Vec<Triple>> {
328+
let embedding = executor
329+
.context()
330+
.embedding_model
331+
.embed(vec![&query], None)
332+
.expect("Failed to get embedding")
333+
.pop()
334+
.expect("Embedding is empty")
335+
.into_iter()
336+
.map(|v| v as f64)
337+
.collect::<Vec<_>>();
338+
339+
let query = mapping::triple::semantic_search(&executor.context().neo4j, embedding)
340+
.limit(first as usize);
341+
342+
Ok(query
343+
.send()
344+
.await?
345+
.map_ok(|search_result| Triple::new(search_result.triple, search_result.space_id, None))
346+
.try_collect::<Vec<_>>()
347+
.await?)
348+
}
320349
}

api/src/schema/triple.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::fmt::Display;
22

3-
use juniper::{graphql_object, Executor, GraphQLEnum, GraphQLObject, ScalarValue};
3+
use juniper::{graphql_object, Executor, FieldResult, GraphQLEnum, GraphQLObject, ScalarValue};
44

55
use grc20_core::{
66
mapping::{self, query_utils::Query, triple},
@@ -9,8 +9,11 @@ use grc20_core::{
99

1010
use crate::context::KnowledgeGraph;
1111

12+
use super::Entity;
13+
1214
#[derive(Debug)]
1315
pub struct Triple {
16+
entity_id: String,
1417
pub attribute: String,
1518
pub value: String,
1619
pub value_type: ValueType,
@@ -23,6 +26,7 @@ pub struct Triple {
2326
impl Triple {
2427
pub fn new(triple: mapping::Triple, space_id: String, space_version: Option<String>) -> Self {
2528
Self {
29+
entity_id: triple.entity,
2630
attribute: triple.attribute,
2731
value: triple.value.value,
2832
value_type: triple.value.value_type.into(),
@@ -79,6 +83,20 @@ impl Triple {
7983
.expect("Failed to find triple name attribute")
8084
.map(|triple| triple.value.value)
8185
}
86+
87+
async fn entity<'a, S: ScalarValue>(
88+
&'a self,
89+
executor: &'a Executor<'_, '_, KnowledgeGraph, S>,
90+
) -> FieldResult<Option<Entity>> {
91+
Entity::load(
92+
&executor.context().neo4j,
93+
self.entity_id.clone(),
94+
self.space_id.clone(),
95+
self.space_version.clone(),
96+
false,
97+
)
98+
.await
99+
}
82100
}
83101

84102
impl From<mapping::ValueType> for ValueType {

docker/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
FROM rust:1.81.0 AS builder
1+
FROM rust:1.87.0 AS builder
22

33
WORKDIR /kg-node
44
COPY . .
55
RUN apt-get update && apt-get upgrade -y
66
RUN apt-get install libssl-dev protobuf-compiler -y
7-
RUN cargo build --release --bin sink --bin api
7+
RUN CFLAGS='-std=gnu17' cargo build --release --bin sink --bin api
88

99
# Run image
1010
FROM debian:bookworm-slim AS run

grc20-core/src/ids/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,8 @@ pub mod network_ids;
55
pub mod system_ids;
66

77
pub use id::*;
8+
9+
pub fn indexed(id: &str) -> bool {
10+
// Add other ids to this list as needed
11+
id == system_ids::DESCRIPTION_ATTRIBUTE || id == system_ids::NAME_ATTRIBUTE
12+
}

grc20-core/src/mapping/attributes.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -837,15 +837,11 @@ mod tests {
837837
.await
838838
.expect("Insert failed");
839839

840-
Triple {
841-
entity: "abc".to_string(),
842-
attribute: "bar".to_string(),
843-
value: 456u64.into(),
844-
}
845-
.insert(&neo4j, &BlockMetadata::default(), "space_id", "1")
846-
.send()
847-
.await
848-
.expect("Failed to insert triple");
840+
Triple::new("abc", "bar", 456u64)
841+
.insert(&neo4j, &BlockMetadata::default(), "space_id", "1")
842+
.send()
843+
.await
844+
.expect("Failed to insert triple");
849845

850846
let foo_v2 = entity::find_one::<Entity<Foo>>(&neo4j, "abc")
851847
.space_id("space_id")

grc20-core/src/mapping/entity/find_one.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,7 @@ mod tests {
168168
// Setup a local Neo 4J container for testing. NOTE: docker service must be running.
169169
let (_container, neo4j) = crate::test_utils::setup_neo4j().await;
170170

171-
let triple = Triple {
172-
entity: "abc".to_string(),
173-
attribute: "name".to_string(),
174-
value: "Alice".into(),
175-
};
171+
let triple = Triple::new("abc", "name", "Alice");
176172

177173
triple
178174
.insert(&neo4j, &BlockMetadata::default(), "space_id", "0")

0 commit comments

Comments
 (0)