@@ -244,12 +244,11 @@ class Tree(Datastructure):
244244 >>> branch.add(leaf1)
245245 >>> branch.add(leaf2)
246246 >>> print(tree)
247- <Tree with 4 nodes, 1 branches, and 2 leaves>
248- >>> tree.print()
249- <TreeNode root>
250- <TreeNode branch>
251- <TreeNode leaf2>
252- <TreeNode leaf1>
247+ <Tree with 4 nodes>
248+ |--<TreeNode: root>
249+ |-- <TreeNode: branch>
250+ |-- <TreeNode: leaf1>
251+ |-- <TreeNode: leaf2>
253252
254253 """
255254
@@ -281,6 +280,9 @@ def __init__(self, name=None, **kwargs):
281280 super (Tree , self ).__init__ (kwargs , name = name )
282281 self ._root = None
283282
283+ def __str__ (self ):
284+ return "<Tree with {} nodes>\n {}" .format (len (list (self .nodes )), self .get_hierarchy_string (max_depth = 3 ))
285+
284286 @property
285287 def root (self ):
286288 return self ._root
@@ -435,12 +437,9 @@ def get_nodes_by_name(self, name):
435437 nodes .append (node )
436438 return nodes
437439
438- def __repr__ (self ):
439- return "<Tree with {} nodes>" .format (len (list (self .nodes )))
440-
441- def print_hierarchy (self , max_depth = None ):
440+ def get_hierarchy_string (self , max_depth = None ):
442441 """
443- Print the spatial hierarchy of the tree.
442+ Return string representation for the spatial hierarchy of the tree.
444443
445444 Parameters
446445 ----------
@@ -450,22 +449,27 @@ def print_hierarchy(self, max_depth=None):
450449
451450 Returns
452451 -------
453- None
452+ str
453+ String representing the spatial hierarchy of the tree.
454454
455455 """
456456
457- def _print (node , prefix = "" , last = True , depth = 0 ):
457+ hierarchy = []
458+
459+ def traverse (node , hierarchy , prefix = "" , last = True , depth = 0 ):
458460
459461 if max_depth is not None and depth > max_depth :
460462 return
461463
462464 connector = "└── " if last else "├── "
463- print ("{}{}{}" .format (prefix , connector , node ))
465+ hierarchy . append ("{}{}{}" .format (prefix , connector , node ))
464466 prefix += " " if last else "│ "
465467 for i , child in enumerate (node .children ):
466- _print (child , prefix , i == len (node .children ) - 1 , depth + 1 )
468+ traverse (child , hierarchy , prefix , i == len (node .children ) - 1 , depth + 1 )
469+
470+ traverse (self .root , hierarchy )
467471
468- _print ( self . root )
472+ return " \n " . join ( hierarchy )
469473
470474 def to_graph (self , key_mapper = None ):
471475 """Convert the tree to a graph.
0 commit comments