diff --git a/doc/source/relationships.rst b/doc/source/relationships.rst index b1be1e94..f3a2bb7a 100644 --- a/doc/source/relationships.rst +++ b/doc/source/relationships.rst @@ -173,5 +173,8 @@ The ``defintion`` argument is a :term:`py3:mapping` with these items: ``direction`` ``match.OUTGOING`` / ``match.INCOMING`` / ``match.EITHER`` ``relation_type`` Can be ``None`` (for any direction), ``*`` for all paths or an explicit name of a relation type (the edge's label). + Matching multiple labels can be done by supplying a list of + names. This will match any edge that matches at least one of + these. ``model`` The class of the relation model, ``None`` for such without one. ================= =============================================================== diff --git a/neomodel/match.py b/neomodel/match.py index 12ffa992..50bc2e4f 100644 --- a/neomodel/match.py +++ b/neomodel/match.py @@ -28,7 +28,7 @@ def _rel_helper(lhs, rhs, ident=None, relation_type=None, direction=None, relati :type rhs: str :param ident: A specific identity to name the relationship, or None. :type ident: str - :param relation_type: None for all direct rels, * for all of any length, or a name of an explicit rel. + :param relation_type: None for all direct rels, * for all of any length, a name of an explicit rel, or a list of names. :type relation_type: str :param direction: None or EITHER for all OUTGOING,INCOMING,EITHER. Otherwise OUTGOING or INCOMING. :param relation_properties: dictionary of relationship properties to match @@ -56,7 +56,15 @@ def _rel_helper(lhs, rhs, ident=None, relation_type=None, direction=None, relati stmt = stmt.format('[*]') else: # explicit relation_type - stmt = stmt.format('[{0}:`{1}`{2}]'.format(ident if ident else '', relation_type, rel_props)) + # if multiple relationship types are given, use OR syntax (:TYPE1|TYPE2|TYPE3...) + if type(relation_type) != str and hasattr(relation_type, '__iter__'): + # we also have to escape them here, as we cannot escape the built string later. + # `TYPE1|TYPE2` will be interpreted as a single value + relation_type = "|".join(['`' + label + '`' for label in relation_type]) + else: + # if it is a single value, escape it + relation_type = '`' + relation_type + '`' + stmt = stmt.format('[{0}:{1}{2}]'.format(ident if ident else '', relation_type, rel_props)) return "({0}){1}({2})".format(lhs, stmt, rhs) @@ -323,6 +331,7 @@ def build_traversal(self, traversal): lhs_ident = self.build_source(traversal.source) rhs_ident = traversal.name + rhs_label self._ast['return'] = traversal.name + self._ast['return_mod'] = 'DISTINCT' self._ast['result_class'] = traversal.target_class rel_ident = self.create_ident() @@ -461,7 +470,7 @@ def build_query(self): query += ' WITH ' query += self._ast['with'] - query += ' RETURN ' + self._ast['return'] + query += ' RETURN ' + self._ast.get('return_mod', '') + ' ' + self._ast['return'] if 'order_by' in self._ast and self._ast['order_by']: query += ' ORDER BY ' @@ -476,7 +485,7 @@ def build_query(self): return query def _count(self): - self._ast['return'] = 'count({0})'.format(self._ast['return']) + self._ast['return'] = 'count({0} {1})'.format(self._ast.pop('return_mod', ''), self._ast['return']) # drop order_by, results in an invalid query self._ast.pop('order_by', None) query = self.build_query() diff --git a/test/test_relationships.py b/test/test_relationships.py index ee803caa..51d247e3 100644 --- a/test/test_relationships.py +++ b/test/test_relationships.py @@ -1,7 +1,7 @@ from pytest import raises from neomodel import (StructuredNode, RelationshipTo, RelationshipFrom, Relationship, - StringProperty, IntegerProperty, StructuredRel, One) + StringProperty, IntegerProperty, StructuredRel, One, Traversal) class PersonWithRels(StructuredNode): @@ -182,3 +182,51 @@ def test_props_relationship(): with raises(NotImplementedError): c.inhabitant.connect(u, properties={'city': 'Thessaloniki'}) + +def test_multiple_label_relationship_traversal(): + #set up country and two persons + p1 = PersonWithRels(name="Max", age=20).save() + p2 = PersonWithRels(name="Moritz", age=21).save() + c1 = Country(code="IO").save() + + assert p1 + assert p2 + assert c1 + + c1.inhabitant.connect(p1) + c1.president.connect(p2) + + assert len(c1.inhabitant) == 1 + assert len(c1.president) == 1 + + # test that both inhabitant and president is returned when specifying both + definition = dict(node_class=PersonWithRels, direction=None, + relation_type=('IS_FROM', 'PRESIDENT'), model=None) + relations_traversal = Traversal(c1, PersonWithRels.__label__, + definition) + + assert len(relations_traversal) == 2 + + assert p1 in relations_traversal + assert p2 in relations_traversal + + # add president as inhabitant + c1.inhabitant.connect(p2) + + # test that we still get only two results from the same traversal + relations_traversal = Traversal(c1, PersonWithRels.__label__, + definition) + + # p2 is connected twice, but should only be returned once + assert len(relations_traversal) == 2 + + assert p1 in relations_traversal + assert p2 in relations_traversal + + # check if lazy evaluation also honors the distinct restriction + lazy_evaluated = Traversal(c1, PersonWithRels.__label__, definition).all(True) + + assert len(lazy_evaluated) == 2 + + assert p1.id in lazy_evaluated + assert p2.id in lazy_evaluated