diff --git a/src/ops/storages/neo4j.rs b/src/ops/storages/neo4j.rs index b467e0346..8b1915aef 100644 --- a/src/ops/storages/neo4j.rs +++ b/src/ops/storages/neo4j.rs @@ -362,14 +362,12 @@ impl ExportContext { rel_type = rel_spec.rel_type, }; - let analyzed_src = analyzed_data_coll - .source + let analyzed_rel = analyzed_data_coll + .rel .as_ref() - .ok_or_else(|| anyhow!("Relationship spec requires source fields"))?; - let analyzed_tgt = analyzed_data_coll - .target - .as_ref() - .ok_or_else(|| anyhow!("Relationship spec requires target fields"))?; + .ok_or_else(invariance_violation)?; + let analyzed_src = &analyzed_rel.source; + let analyzed_tgt = &analyzed_rel.target; let (src_key_field_params, src_key_fields_literal) = Self::build_key_field_params_n_literal( @@ -475,57 +473,41 @@ impl ExportContext { } let value = &upsert.value; - let mut insert_cypher = + let mut query = self.bind_rel_key_field_params(neo4rs::query(&self.insert_cypher), &upsert.key)?; - if let Some(analyzed_src) = &self.analyzed_data_coll.source { - insert_cypher = Self::bind_key_field_params( - insert_cypher, - &self.src_key_field_params, - std::iter::zip( - analyzed_src.schema.key_fields.iter(), - analyzed_src.fields_input_idx.key.iter(), - ) - .map(|(f, field_idx)| (&f.value_type.typ, &value.fields[*field_idx])), - )?; - - if analyzed_src.has_value_fields() { - insert_cypher = insert_cypher.param( - SRC_PROPS_PARAM, - mapped_field_values_to_bolt( - &analyzed_src.schema.value_fields, - &analyzed_src.fields_input_idx.value, - value, - )?, - ); - } - } - - if let Some(analyzed_tgt) = &self.analyzed_data_coll.target { - insert_cypher = Self::bind_key_field_params( - insert_cypher, - &self.tgt_key_field_params, - std::iter::zip( - analyzed_tgt.schema.key_fields.iter(), - analyzed_tgt.fields_input_idx.key.iter(), - ) - .map(|(f, field_idx)| (&f.value_type.typ, &value.fields[*field_idx])), - )?; - - if analyzed_tgt.has_value_fields() { - insert_cypher = insert_cypher.param( - TGT_PROPS_PARAM, - mapped_field_values_to_bolt( - &analyzed_tgt.schema.value_fields, - &analyzed_tgt.fields_input_idx.value, - value, - )?, - ); - } + if let Some(analyzed_rel) = &self.analyzed_data_coll.rel { + let bind_params = |query: neo4rs::Query, + analyzed: &AnalyzedGraphElementFieldMapping, + key_field_params: &[String]| + -> Result { + let mut query = Self::bind_key_field_params( + query, + key_field_params, + std::iter::zip( + analyzed.schema.key_fields.iter(), + analyzed.fields_input_idx.key.iter(), + ) + .map(|(f, field_idx)| (&f.value_type.typ, &value.fields[*field_idx])), + )?; + if analyzed.has_value_fields() { + query = query.param( + SRC_PROPS_PARAM, + mapped_field_values_to_bolt( + &analyzed.schema.value_fields, + &analyzed.fields_input_idx.value, + value, + )?, + ); + } + Ok(query) + }; + query = bind_params(query, &analyzed_rel.source, &self.src_key_field_params)?; + query = bind_params(query, &analyzed_rel.target, &self.tgt_key_field_params)?; } if !self.analyzed_data_coll.value_fields_input_idx.is_empty() { - insert_cypher = insert_cypher.param( + query = query.param( CORE_PROPS_PARAM, mapped_field_values_to_bolt( &self.analyzed_data_coll.schema.value_fields, @@ -534,7 +516,7 @@ impl ExportContext { )?, ); } - queries.push(insert_cypher); + queries.push(query); Ok(()) } diff --git a/src/ops/storages/shared/property_graph.rs b/src/ops/storages/shared/property_graph.rs index bd5065538..44e3324e4 100644 --- a/src/ops/storages/shared/property_graph.rs +++ b/src/ops/storages/shared/property_graph.rs @@ -132,22 +132,24 @@ impl AnalyzedGraphElementFieldMapping { } } +pub struct AnalyzedRelationshipInfo { + pub source: AnalyzedGraphElementFieldMapping, + pub target: AnalyzedGraphElementFieldMapping, +} + pub struct AnalyzedDataCollection { pub schema: Arc, pub value_fields_input_idx: Vec, - pub source: Option, - pub target: Option, + pub rel: Option, } impl AnalyzedDataCollection { pub fn dependent_node_labels(&self) -> IndexSet<&str> { let mut dependent_node_labels = IndexSet::new(); - if let Some(source) = &self.source { - dependent_node_labels.insert(source.schema.elem_type.label()); - } - if let Some(target) = &self.target { - dependent_node_labels.insert(target.schema.elem_type.label()); + if let Some(rel) = &self.rel { + dependent_node_labels.insert(rel.source.schema.elem_type.label()); + dependent_node_labels.insert(rel.target.schema.elem_type.label()); } dependent_node_labels } @@ -514,26 +516,27 @@ pub fn analyze_graph_mappings<'a, AuthEntry: 'a>( .ok_or_else(invariance_violation)? .clone(), value_fields_input_idx: processed_info.value_input_fields_idx, - source: None, - target: None, + rel: None, }, // Relationship Some(rel_info) => AnalyzedDataCollection { schema: Arc::new(rel_info.rel_schema), value_fields_input_idx: processed_info.value_input_fields_idx, - source: Some(AnalyzedGraphElementFieldMapping { - schema: node_schemas - .get(&rel_info.source_typ) - .ok_or_else(invariance_violation)? - .clone(), - fields_input_idx: rel_info.source_fields_idx, - }), - target: Some(AnalyzedGraphElementFieldMapping { - schema: node_schemas - .get(&rel_info.target_typ) - .ok_or_else(invariance_violation)? - .clone(), - fields_input_idx: rel_info.target_fields_idx, + rel: Some(AnalyzedRelationshipInfo { + source: AnalyzedGraphElementFieldMapping { + schema: node_schemas + .get(&rel_info.source_typ) + .ok_or_else(invariance_violation)? + .clone(), + fields_input_idx: rel_info.source_fields_idx, + }, + target: AnalyzedGraphElementFieldMapping { + schema: node_schemas + .get(&rel_info.target_typ) + .ok_or_else(invariance_violation)? + .clone(), + fields_input_idx: rel_info.target_fields_idx, + }, }), }, };