Skip to content

Nodes don't get removed when customizing 'expand' function #37

@rc2000123

Description

@rc2000123

I want to create a custom expand function: when you double-click a parent node, the elements data is modified so that the parent node is removed and two new child nodes take its place. When I set up the expand function, the new nodes appear, but the original node does not disappear, even though it has been removed from the elements variable. Here is a modified version of the demo example.

import streamlit as st
from st_link_analysis import st_link_analysis, NodeStyle, EdgeStyle, Event

# Use Streamlit's wide page layout (full browser width) for the graph view.
st.set_page_config(layout="wide")
# Parent node id -> ids of the child nodes that replace it on "expand".
parent_child_dict = {2: [4, 5]}

# Sample Data
# Cytoscape-style element lists. Each node row is (id, label, name); each
# edge row is (id, label, source, target). The CHILD nodes and CHILD_EDGE
# edges are the ones revealed once PARENT node 2 is expanded.
elements = {
    "nodes": [
        {"data": {"id": node_id, "label": label, "name": name}}
        for node_id, label, name in (
            (1, "REGULAR", "Streamlit"),
            (2, "PARENT", "PARENT2"),
            (3, "REGULAR", "World"),
            (4, "CHILD", "CHILD4"),
            (5, "CHILD", "CHILD5"),
        )
    ],
    "edges": [
        {"data": {"id": edge_id, "label": label, "source": source, "target": target}}
        for edge_id, label, source, target in (
            (7, "REGULAR_EDGE", 1, 2),
            (8, "REGULAR_EDGE", 2, 3),
            (9, "CHILD_EDGE", 1, 4),
            (10, "CHILD_EDGE", 4, 3),
            (11, "CHILD_EDGE", 1, 5),
            (12, "CHILD_EDGE", 5, 3),
        )
    ],
}


class DummyGraph:
    """In-memory graph that tracks which sample nodes/edges are visible.

    ``all_nodes``/``all_edges`` hold every element that can ever be shown;
    ``nodes``/``edges`` are the sets of ids currently visible. CHILD nodes
    start hidden and are revealed by :meth:`expand` on their PARENT.

    NOTE(review): after ``expand`` the parent id is dropped from ``nodes``, so
    ``get_elements()`` no longer returns it. If the rendered graph still shows
    the parent, the stale element presumably lives in the frontend component
    (it may merge new elements instead of replacing them) — TODO confirm
    against st_link_analysis.
    """

    def __init__(self, data=None, children=None):
        """Build the graph and hide every CHILD node.

        Args:
            data: Cytoscape-style ``{"nodes": [...], "edges": [...]}`` dict.
                Defaults to the module-level ``elements`` sample data.
            children: Mapping of parent node id -> list of child node ids,
                used by :meth:`expand`. Defaults to ``parent_child_dict``.
        """
        if data is None:
            data = elements
        if children is None:
            children = parent_child_dict
        self.all_nodes = data["nodes"]
        self.all_edges = data["edges"]
        self.children = children
        self.nodes = {n["data"]["id"] for n in self.all_nodes}
        self.edges = {e["data"]["id"] for e in self.all_edges}

        # CHILD nodes are only shown once their parent is expanded.
        self.remove(
            [n["data"]["id"] for n in self.all_nodes if n["data"]["label"] == "CHILD"]
        )

    def get_elements(self):
        """Return the visible subset of nodes and edges in the input format."""
        return {
            "nodes": [n for n in self.all_nodes if n["data"]["id"] in self.nodes],
            "edges": [e for e in self.all_edges if e["data"]["id"] in self.edges],
        }

    @staticmethod
    def _id_key(node_id):
        """Return the key used to compare node ids.

        The sample data uses ``int`` ids, but ids handed back by the component
        may be strings (presumably — the original ``expand`` coerced them with
        ``int()``; TODO confirm). Comparing string forms lets ``1`` match
        ``"1"`` without rejecting non-numeric ids.
        """
        return str(node_id)

    def remove(self, node_ids):
        """Hide the given nodes and every edge that touches them.

        Args:
            node_ids: Iterable of node ids (``int`` or ``str``). Ids that are
                not currently visible are ignored.
        """
        print("remove", node_ids)
        targets = {self._id_key(n) for n in node_ids}
        self.nodes = {n for n in self.nodes if self._id_key(n) not in targets}
        self._update_edges()

    def expand(self, node_ids):
        """Replace a single PARENT node with its children.

        Args:
            node_ids: Sequence of selected node ids; exactly one is expected.

        Prints an error and leaves the graph unchanged when zero or several
        ids are given, the id is unknown, the node is not a PARENT, or the
        parent has no entry in ``children``.
        """
        print("expand called", node_ids)
        print("original nodes", self.nodes)
        if not node_ids:
            print("error: no node id selected")
            return
        if len(node_ids) > 1:
            print("error: multiple node ids selected")
            return

        key = self._id_key(node_ids[0])
        node = next(
            (n for n in self.all_nodes if self._id_key(n["data"]["id"]) == key),
            None,
        )
        if node is None:
            print("error: unknown node", node_ids[0])
            return
        if node["data"]["label"] != "PARENT":
            print("error: this node is not a parent")
            return

        # Use the id exactly as stored in the data so set membership matches.
        parent_id = node["data"]["id"]
        new_nodes = self.children.get(parent_id)
        if new_nodes is None:
            print("error: no children registered for node", parent_id)
            return

        print("Expanding!!")
        # Reveal the children, then hide the parent; _update_edges drops the
        # parent's edges and adds the edges between now-visible children.
        self.nodes.update(new_nodes)
        self.nodes.discard(parent_id)
        print("new nodes", self.nodes)
        self._update_edges()

    def _update_edges(self):
        """Keep only edges whose source and target are both visible."""
        self.edges = {
            e["data"]["id"]
            for e in self.all_edges
            if e["data"]["source"] in self.nodes and e["data"]["target"] in self.nodes
        }


# Widget key shared by the component call and its on_change callback.
COMPONENT_KEY = "NODE_ACTIONS"

# Build the graph once per session; later reruns reuse the stored instance.
if "graph" not in st.session_state:
    st.session_state.graph = DummyGraph()

# Style node & edge groups (one style per element label).
# NOTE(review): PARENT/CHILD captions read a "content" field that the sample
# node data does not define (nodes carry "name") — confirm this is intended.
node_styles = [
    NodeStyle(label, color, caption, icon)
    for label, color, caption, icon in (
        ("REGULAR", "#FF7F3E", "name", "regular"),
        ("PARENT", "#2A629A", "content", "parent"),
        ("CHILD", "#008000", "content", "child"),
    )
]

edge_styles = [
    EdgeStyle(label, caption="label", directed=True)
    for label in ("REGULAR_EDGE", "CHILD_EDGE")
]


def onchange_callback():
    """Forward the component's node action to the session graph.

    Reads the component value stored under ``COMPONENT_KEY`` and calls
    ``remove`` or ``expand`` on ``st.session_state.graph`` with the selected
    node ids. Any other action is only logged.
    """
    payload = st.session_state[COMPONENT_KEY]
    action = payload["action"]

    print("action:", action)

    if action not in ("remove", "expand"):
        return

    graph = st.session_state.graph
    selected = payload["data"]["node_ids"]
    if action == "remove":
        graph.remove(selected)
    else:
        graph.expand(selected)
        
# Currently visible nodes/edges, passed to the component below.
elements = st.session_state.graph.get_elements()
print("new elements:\n", elements)

# Custom events are not used; node actions arrive through on_change instead.
# Example of the Event API, kept for reference:
# events = [
#     Event("clicked_node", "click tap", "node"),
#     Event("another_name", "dblclick dbltap", "*"),
# ]

placeholder = st.empty()

with placeholder.container(border=True, key="abc"):
    # Actions listed in node_actions are reported via on_change
    # (handled by onchange_callback).
    vals = st_link_analysis(
        elements,
        node_styles=node_styles,
        edge_styles=edge_styles,
        key=COMPONENT_KEY,
        node_actions=["remove", "expand"],
        on_change=onchange_callback,
    )

    # Debug output: the component's last returned value and the elements
    # that were handed to it on this run.
    st.markdown("#### Returned Value")
    st.json(vals or {}, expanded=True)
    st.markdown("#### Elements")
    st.json(elements or {}, expanded=True)

@st.cache_data
def get_source():
    """Return this script's own source text for the "Source" expander.

    The file is read explicitly as UTF-8 because it contains non-ASCII
    characters (the expander icon below). Under a non-UTF-8 platform default
    encoding (e.g. cp1252 on Windows), those characters could be garbled or
    the read could raise ``UnicodeDecodeError``.
    """
    with open(__file__, encoding="utf-8") as f:
        return f.read()


source = get_source()
with st.expander("Source", expanded=False, icon="💻"):
    st.code(source, language="python")

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions