Skip to content

Commit 2759e0e

Browse files
authored
Merge pull request #201 from jchanvfx/auto_layout_func_into_graph
Auto layout func into graph
2 parents 17f1a53 + 11e8bfd commit 2759e0e

File tree

4 files changed

+186
-241
lines changed

4 files changed

+186
-241
lines changed

NodeGraphQt/base/graph.py

Lines changed: 152 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
from .model import NodeGraphModel
1818
from .node import NodeObject, BaseNode, BackdropNode
1919
from .port import Port
20-
from ..constants import (URI_SCHEME, URN_SCHEME,
21-
PIPE_LAYOUT_CURVED,
22-
PIPE_LAYOUT_STRAIGHT,
23-
PIPE_LAYOUT_ANGLE,
24-
IN_PORT, OUT_PORT,
25-
VIEWER_GRID_LINES)
20+
from ..constants import (
21+
URI_SCHEME, URN_SCHEME,
22+
NODE_LAYOUT_DIRECTION, NODE_LAYOUT_HORIZONTAL, NODE_LAYOUT_VERTICAL,
23+
PIPE_LAYOUT_CURVED, PIPE_LAYOUT_STRAIGHT, PIPE_LAYOUT_ANGLE,
24+
IN_PORT, OUT_PORT,
25+
VIEWER_GRID_LINES
26+
)
2627
from ..widgets.node_space_bar import node_space_bar
2728
from ..widgets.viewer import NodeViewer
2829

@@ -1558,6 +1559,149 @@ def disable_nodes(self, nodes, mode=None):
15581559
return
15591560
nodes[0].set_disabled(mode)
15601561

1562+
# auto layout node functions.
1563+
1564+
@staticmethod
1565+
def _update_node_rank(node, nodes_rank, down_stream=True):
1566+
"""
1567+
Recursive function for updating the node ranking.
1568+
1569+
Args:
1570+
node (NodeGraphQt.BaseNode): node to start from.
1571+
nodes_rank (dict): node ranking object to be updated.
1572+
down_stream (bool): true to rank down stram.
1573+
"""
1574+
if down_stream:
1575+
node_values = node.connected_output_nodes().values()
1576+
else:
1577+
node_values = node.connected_input_nodes().values()
1578+
1579+
connected_nodes = set()
1580+
for nodes in node_values:
1581+
connected_nodes.update(nodes)
1582+
1583+
rank = nodes_rank[node] + 1
1584+
for n in connected_nodes:
1585+
if n in nodes_rank:
1586+
nodes_rank[n] = max(nodes_rank[n], rank)
1587+
else:
1588+
nodes_rank[n] = rank
1589+
NodeGraph._update_node_rank(n, nodes_rank, down_stream)
1590+
1591+
@staticmethod
1592+
def _compute_node_rank(nodes, down_stream=True):
1593+
"""
1594+
Compute the ranking of nodes.
1595+
1596+
Args:
1597+
nodes (list[NodeGraphQt.BaseNode]): nodes to start ranking from.
1598+
down_stream (bool): true to compute down stream.
1599+
1600+
Returns:
1601+
dict: {NodeGraphQt.BaseNode: node_rank, ...}
1602+
"""
1603+
nodes_rank = {}
1604+
for node in nodes:
1605+
nodes_rank[node] = 0
1606+
NodeGraph._update_node_rank(node, nodes_rank, down_stream)
1607+
return nodes_rank
1608+
1609+
def auto_layout_nodes(self, nodes=None, down_stream=True, start_nodes=None):
1610+
"""
1611+
Auto layout the nodes in the node graph.
1612+
1613+
Note:
1614+
If the node graph is acyclic then the ``start_nodes`` will need
1615+
to be specified.
1616+
1617+
Args:
1618+
nodes (list[NodeGraphQt.BaseNode]): list of nodes to auto layout
1619+
if nodes is None then all nodes is layed out.
1620+
down_stream (bool): false to layout up stream.
1621+
start_nodes (list[NodeGraphQt.BaseNode]):
1622+
list of nodes to start the auto layout from (Optional).
1623+
"""
1624+
self.begin_undo('Auto Layout Nodes')
1625+
1626+
nodes = nodes or self.all_nodes()
1627+
1628+
# filter out the backdrops.
1629+
backdrops = {
1630+
n: n.nodes() for n in nodes if isinstance(n, BackdropNode)
1631+
}
1632+
filtered_nodes = [n for n in nodes if not isinstance(n, BackdropNode)]
1633+
1634+
start_nodes = start_nodes or []
1635+
if down_stream:
1636+
start_nodes += [
1637+
n for n in filtered_nodes
1638+
if not any(n.connected_input_nodes().values())
1639+
]
1640+
else:
1641+
start_nodes += [
1642+
n for n in filtered_nodes
1643+
if not any(n.connected_output_nodes().values())
1644+
]
1645+
1646+
if not start_nodes:
1647+
return
1648+
1649+
node_views = [n.view for n in nodes]
1650+
nodes_center_0 = self.viewer().nodes_rect_center(node_views)
1651+
1652+
nodes_rank = NodeGraph._compute_node_rank(start_nodes, down_stream)
1653+
1654+
rank_map = {}
1655+
for node, rank in nodes_rank.items():
1656+
if rank in rank_map:
1657+
rank_map[rank].append(node)
1658+
else:
1659+
rank_map[rank] = [node]
1660+
1661+
if NODE_LAYOUT_DIRECTION is NODE_LAYOUT_HORIZONTAL:
1662+
current_x = 0
1663+
node_height = 120
1664+
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
1665+
ranked_nodes = rank_map[rank]
1666+
max_width = max([node.view.width for node in ranked_nodes])
1667+
current_x += max_width
1668+
current_y = 0
1669+
for idx, node in enumerate(ranked_nodes):
1670+
dy = max(node_height, node.view.height)
1671+
current_y += 0 if idx == 0 else dy
1672+
node.set_pos(current_x, current_y)
1673+
current_y += dy * 0.5 + 10
1674+
1675+
current_x += max_width * 0.5 + 100
1676+
elif NODE_LAYOUT_DIRECTION is NODE_LAYOUT_VERTICAL:
1677+
current_y = 0
1678+
node_width = 250
1679+
for rank in sorted(range(len(rank_map)), reverse=not down_stream):
1680+
ranked_nodes = rank_map[rank]
1681+
max_height = max([node.view.height for node in ranked_nodes])
1682+
current_y += max_height
1683+
current_x = 0
1684+
for idx, node in enumerate(ranked_nodes):
1685+
dx = max(node_width, node.view.width)
1686+
current_x += 0 if idx == 0 else dx
1687+
node.set_pos(current_x, current_y)
1688+
current_x += dx * 0.5 + 10
1689+
1690+
current_y += max_height * 0.5 + 100
1691+
1692+
nodes_center_1 = self.viewer().nodes_rect_center(node_views)
1693+
dx = nodes_center_0[0] - nodes_center_1[0]
1694+
dy = nodes_center_0[1] - nodes_center_1[1]
1695+
[n.set_pos(n.x_pos() + dx, n.y_pos() + dy) for n in nodes]
1696+
1697+
# wrap the backdrop nodes.
1698+
for backdrop, contained_nodes in backdrops.items():
1699+
backdrop.wrap_nodes(contained_nodes)
1700+
1701+
self.end_undo()
1702+
1703+
# prompt dialog functions.
1704+
15611705
def question_dialog(self, text, title='Node Graph'):
15621706
"""
15631707
Prompts a question open dialog with ``"Yes"`` and ``"No"`` buttons in
@@ -1624,6 +1768,8 @@ def save_dialog(self, current_dir=None, ext=None):
16241768
"""
16251769
return self._viewer.save_dialog(current_dir, ext)
16261770

1771+
### ---
1772+
16271773
def use_OpenGL(self):
16281774
"""
16291775
Use OpenGL to draw the graph.

0 commit comments

Comments
 (0)