@@ -494,6 +494,22 @@ def get_coordinates(self, as_array: bool = True) -> np.ndarray:
494494 coordinates = np .array (coordinates )
495495 return coordinates
496496
497+ def set_coordinates (self , coordinates : list [tuple ]) -> np .ndarray :
498+ """Get the coordinates of the root segments.
499+
500+ Args:
501+ as_array (bool, optional):
502+ Return the coordinates as a Numpy array. Defaults to True.
503+
504+ Returns:
505+ np.ndarray:
506+ The coordinates of the root segments
507+ """
508+ for i , segment in enumerate (self .segments ):
509+ node_data = segment .node_data
510+ node_data .x , node_data .y , node_data .z = coordinates [i ]
511+ return coordinates
512+
497513 def get_diameters (self , as_array : bool = True ) -> np .ndarray :
498514 """Get the diameters of the root segments.
499515
@@ -644,8 +660,6 @@ def __transform(**kwargs):
644660 if coin_flip == 1 :
645661 pitch *= - 1
646662
647- # No upwards growing roots
648- # Gravitropism
649663 iter_count = 0
650664 current_order = self .segments [0 ].node_data .order
651665 if current_order > 1 :
@@ -655,9 +669,12 @@ def __transform(**kwargs):
655669 __transform (pitch = pitch )
656670 iter_count += 1
657671
658- # Coordinates above no root zone
659672 iter_count = 0
660673 coordinates = self .get_coordinates ()
674+ if np .any (coordinates [:, 2 ] > no_root_zone ):
675+ coordinates [:, 2 ] *= - 1
676+ self .set_coordinates (coordinates )
677+
661678 while np .any (coordinates [:, 2 ] > no_root_zone ):
662679 if iter_count > max_attempts :
663680 return self .cascading_set_invalid_root ()
0 commit comments