diff --git a/examples/docs_to_kg/main.py b/examples/docs_to_kg/main.py index fb4c23e45..9813768f9 100644 --- a/examples/docs_to_kg/main.py +++ b/examples/docs_to_kg/main.py @@ -96,35 +96,35 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D "document_node", cocoindex.storages.Neo4j( connection=conn_spec, - mapping=cocoindex.storages.Neo4jNode(label="Document")), + mapping=cocoindex.storages.GraphNode(label="Document")), primary_key_fields=["filename"], ) entity_relationship.export( "entity_relationship", cocoindex.storages.Neo4j( connection=conn_spec, - mapping=cocoindex.storages.Neo4jRelationship( + mapping=cocoindex.storages.GraphRelationship( rel_type="RELATIONSHIP", - source=cocoindex.storages.Neo4jRelationshipEnd( + source=cocoindex.storages.GraphRelationshipEnd( label="Entity", fields=[ - cocoindex.storages.Neo4jFieldMapping( + cocoindex.storages.GraphFieldMapping( field_name="subject", node_field_name="value"), - cocoindex.storages.Neo4jFieldMapping( + cocoindex.storages.GraphFieldMapping( field_name="subject_embedding", node_field_name="embedding"), ] ), - target=cocoindex.storages.Neo4jRelationshipEnd( + target=cocoindex.storages.GraphRelationshipEnd( label="Entity", fields=[ - cocoindex.storages.Neo4jFieldMapping( + cocoindex.storages.GraphFieldMapping( field_name="object", node_field_name="value"), - cocoindex.storages.Neo4jFieldMapping( + cocoindex.storages.GraphFieldMapping( field_name="object_embedding", node_field_name="embedding"), ] ), nodes={ - "Entity": cocoindex.storages.Neo4jRelationshipNode( + "Entity": cocoindex.storages.GraphRelationshipNode( primary_key_fields=["value"], vector_indexes=[ cocoindex.VectorIndexDef( @@ -142,15 +142,15 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D "entity_mention", cocoindex.storages.Neo4j( connection=conn_spec, - mapping=cocoindex.storages.Neo4jRelationship( + mapping=cocoindex.storages.GraphRelationship( rel_type="MENTION", - source=cocoindex.storages.Neo4jRelationshipEnd( + source=cocoindex.storages.GraphRelationshipEnd( label="Document", - fields=[cocoindex.storages.Neo4jFieldMapping("filename")], + fields=[cocoindex.storages.GraphFieldMapping("filename")], ), - target=cocoindex.storages.Neo4jRelationshipEnd( + target=cocoindex.storages.GraphRelationshipEnd( label="Entity", - fields=[cocoindex.storages.Neo4jFieldMapping( + fields=[cocoindex.storages.GraphFieldMapping( field_name="entity", node_field_name="value")], ), ), diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index c1568e2e4..f615cfa3e 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -29,7 +29,7 @@ class Neo4jConnection: db: str | None = None @dataclass -class Neo4jFieldMapping: +class GraphFieldMapping: """Mapping for a Neo4j field.""" field_name: str # Field name for the node in the Knowledge Graph. @@ -37,36 +37,36 @@ class Neo4jFieldMapping: node_field_name: str | None = None @dataclass -class Neo4jRelationshipEnd: +class GraphRelationshipEnd: """Spec for a Neo4j node type.""" label: str - fields: list[Neo4jFieldMapping] + fields: list[GraphFieldMapping] @dataclass -class Neo4jRelationshipNode: +class GraphRelationshipNode: """Spec for a Neo4j node type.""" primary_key_fields: Sequence[str] vector_indexes: Sequence[index.VectorIndexDef] = () @dataclass -class Neo4jNode: +class GraphNode: """Spec for a Neo4j node type.""" kind = "Node" label: str @dataclass -class Neo4jRelationship: +class GraphRelationship: """Spec for a Neo4j relationship.""" kind = "Relationship" rel_type: str - source: Neo4jRelationshipEnd - target: Neo4jRelationshipEnd - nodes: dict[str, Neo4jRelationshipNode] | None = None + source: GraphRelationshipEnd + target: GraphRelationshipEnd + nodes: dict[str, GraphRelationshipNode] | None = None class Neo4j(op.StorageSpec): """Graph storage powered by Neo4j.""" connection: AuthEntryReference - mapping: Neo4jNode | Neo4jRelationship + mapping: GraphNode | GraphRelationship diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index 83ba4f5a1..0db197011 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -19,7 +19,7 @@ pub struct ConnectionSpec { } #[derive(Debug, Deserialize)] -pub struct FieldMapping { +pub struct GraphFieldMappingSpec { field_name: FieldName, /// Field name for the node in the Knowledge Graph. @@ -28,48 +28,48 @@ pub struct FieldMapping { node_field_name: Option, } -impl FieldMapping { +impl GraphFieldMappingSpec { fn get_node_field_name(&self) -> &FieldName { self.node_field_name.as_ref().unwrap_or(&self.field_name) } } #[derive(Debug, Deserialize)] -pub struct RelationshipEndSpec { +pub struct GraphRelationshipEndSpec { label: String, - fields: Vec, + fields: Vec, } #[derive(Debug, Deserialize)] -pub struct RelationshipNodeSpec { +pub struct GraphRelationshipNodeSpec { #[serde(flatten)] index_options: spec::IndexOptions, } #[derive(Debug, Deserialize)] -pub struct NodeSpec { +pub struct GraphNodeSpec { label: String, } #[derive(Debug, Deserialize)] -pub struct RelationshipSpec { +pub struct GraphRelationshipSpec { rel_type: String, - source: RelationshipEndSpec, - target: RelationshipEndSpec, - nodes: Option>, + source: GraphRelationshipEndSpec, + target: GraphRelationshipEndSpec, + nodes: Option>, } #[derive(Debug, Deserialize)] #[serde(tag = "kind")] -pub enum RowMappingSpec { - Relationship(RelationshipSpec), - Node(NodeSpec), +pub enum GraphMappingSpec { + Relationship(GraphRelationshipSpec), + Node(GraphNodeSpec), } #[derive(Debug, Deserialize)] pub struct Spec { connection: AuthEntryReference, - mapping: RowMappingSpec, + mapping: GraphMappingSpec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -101,10 +101,12 @@ impl ElementType { } } - fn from_mapping_spec(spec: &RowMappingSpec) -> Self { + fn from_mapping_spec(spec: &GraphMappingSpec) -> Self { match spec { - RowMappingSpec::Relationship(spec) => ElementType::Relationship(spec.rel_type.clone()), - RowMappingSpec::Node(spec) => ElementType::Node(spec.label.clone()), + GraphMappingSpec::Relationship(spec) => { + ElementType::Relationship(spec.rel_type.clone()) + } + GraphMappingSpec::Node(spec) => ElementType::Node(spec.label.clone()), } } @@ -393,7 +395,7 @@ impl ExportContext { key_fields.iter().map(|f| &f.name), ); let result = match spec.mapping { - RowMappingSpec::Node(node_spec) => { + GraphMappingSpec::Node(node_spec) => { let delete_cypher = formatdoc! {" OPTIONAL MATCH (old_node:{label} {key_fields_literal}) WITH old_node @@ -433,7 +435,7 @@ impl ExportContext { tgt_fields: None, } } - RowMappingSpec::Relationship(rel_spec) => { + GraphMappingSpec::Relationship(rel_spec) => { let delete_cypher = formatdoc! {" OPTIONAL MATCH (old_src)-[old_rel:{rel_type} {key_fields_literal}]->(old_tgt) @@ -687,8 +689,8 @@ impl RelationshipSetupState { } let mut dependent_node_labels = vec![]; match &spec.mapping { - RowMappingSpec::Node(_) => {} - RowMappingSpec::Relationship(rel_spec) => { + GraphMappingSpec::Node(_) => {} + GraphMappingSpec::Relationship(rel_spec) => { let (src_label_info, tgt_label_info) = end_nodes_label_info.ok_or_else(|| { anyhow!( "Expect `end_nodes_label_info` existing for relationship `{}`", @@ -1079,12 +1081,15 @@ impl Factory { struct DependentNodeLabelAnalyzer<'a> { label_name: &'a str, fields: IndexMap<&'a str, AnalyzedGraphFieldMapping>, - remaining_fields: HashMap<&'a str, &'a FieldMapping>, + remaining_fields: HashMap<&'a str, &'a GraphFieldMappingSpec>, index_options: Option<&'a IndexOptions>, } impl<'a> DependentNodeLabelAnalyzer<'a> { - fn new(rel_spec: &'a RelationshipSpec, rel_end_spec: &'a RelationshipEndSpec) -> Result { + fn new( + rel_spec: &'a GraphRelationshipSpec, + rel_end_spec: &'a GraphRelationshipEndSpec, + ) -> Result { Ok(Self { label_name: rel_end_spec.label.as_str(), fields: IndexMap::new(), @@ -1181,7 +1186,7 @@ impl StorageFactoryBase for Factory { let setup_key = GraphElement::from_spec(&spec); let (value_fields_info, rel_end_label_info) = match &spec.mapping { - RowMappingSpec::Node(_) => ( + GraphMappingSpec::Node(_) => ( value_fields_schema .into_iter() .enumerate() @@ -1193,7 +1198,7 @@ impl StorageFactoryBase for Factory { .collect(), None, ), - RowMappingSpec::Relationship(rel_spec) => { + GraphMappingSpec::Relationship(rel_spec) => { let mut src_label_analyzer = DependentNodeLabelAnalyzer::new(&rel_spec, &rel_spec.source)?; let mut tgt_label_analyzer =