diff --git a/autogl/data/graph/_general_static_graph/_abstract_views.py b/autogl/data/graph/_general_static_graph/_abstract_views.py index 39cd4639..b529b9ae 100644 --- a/autogl/data/graph/_general_static_graph/_abstract_views.py +++ b/autogl/data/graph/_general_static_graph/_abstract_views.py @@ -81,6 +81,10 @@ class HomogeneousEdgesView: def connections(self) -> torch.LongTensor: raise NotImplementedError + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + raise NotImplementedError + @property def data(self) -> HomogeneousEdgesDataView: raise NotImplementedError @@ -91,6 +95,10 @@ class HeterogeneousEdgesView(_typing.Collection[_canonical_edge_type.CanonicalEd def connections(self) -> torch.LongTensor: raise NotImplementedError + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + raise NotImplementedError + @property def data(self) -> HomogeneousEdgesDataView: raise NotImplementedError diff --git a/autogl/data/graph/_general_static_graph/_general_static_graph_default_implementation.py b/autogl/data/graph/_general_static_graph/_general_static_graph_default_implementation.py index b47de072..6d953baa 100644 --- a/autogl/data/graph/_general_static_graph/_general_static_graph_default_implementation.py +++ b/autogl/data/graph/_general_static_graph/_general_static_graph_default_implementation.py @@ -387,6 +387,10 @@ class HomogeneousEdgesContainer: def connections(self) -> torch.Tensor: raise NotImplementedError + @connections.setter + def connections(self, edges: torch.Tensor) -> None: + raise NotImplementedError + @property def data_keys(self) -> _typing.Iterable[str]: raise NotImplementedError @@ -438,6 +442,10 @@ def __init__( def connections(self) -> torch.Tensor: return self.__connections + @connections.setter + def connections(self, edges: torch.Tensor) -> None: + self.__connections = edges + @property def data_keys(self) -> _typing.Iterable[str]: return self.__data.keys() @@ -547,10 +555,15 @@ def _get_edges( self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... ) -> HomogeneousEdgesContainer: if edge_t in (Ellipsis, None): - if len(self.__heterogeneous_edges_data_frame) == 1: + if len(self.__heterogeneous_edges_data_frame) == 0: + raise ValueError("The graph contains no edges") + elif len(self.__heterogeneous_edges_data_frame) == 1: return self.__heterogeneous_edges_data_frame.iloc[0]['edges'] else: - raise RuntimeError # Undetermined + raise ValueError( + "Unable to automatically determine edge type " + "since the graph contains multiple edge types" + ) elif isinstance(edge_t, str): if ' ' in edge_t: raise ValueError @@ -558,9 +571,13 @@ def _get_edges( self.__heterogeneous_edges_data_frame.loc[ self.__heterogeneous_edges_data_frame['r'] == edge_t ] - ) != 1: - raise ValueError # todo: Unable to determine - else: + ) == 0: + raise ValueError(f"The graph has NOT edge with relation type as {edge_t}") + elif len( + self.__heterogeneous_edges_data_frame.loc[ + self.__heterogeneous_edges_data_frame['r'] == edge_t + ] + ) == 1: temp: HomogeneousEdgesContainer = self.__heterogeneous_edges_data_frame.loc[ self.__heterogeneous_edges_data_frame['r'] == edge_t, 'edges' ] @@ -568,6 +585,11 @@ def _get_edges( raise RuntimeError else: return temp + else: + raise ValueError( + f"Unable to determine canonical edge type by relation type \"{edge_t}\", " + f"since the graph contains multiple edge types with relation type as \"{edge_t}\"" + ) elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType): if isinstance(edge_t, _typing.Tuple) and not ( len(edge_t) == 3 and @@ -625,7 +647,10 @@ def _set_edges( else HomogeneousEdgesContainerImplementation(edges) ) else: - raise RuntimeError # todo: Unable to determine error + raise ValueError( + "Unable to set edges for heterogeneous graph consist of multiple edge types " + "with automatically determined edge type" + ) elif isinstance(edge_t, str): if ' ' in edge_t: raise ValueError @@ -694,9 +719,9 @@ def _set_edges( else HomogeneousEdgesContainerImplementation(edges) ) else: - raise RuntimeError # todo: Unable to determine error + raise RuntimeError else: - raise RuntimeError + raise TypeError("Unsupported edge type") def _delete_edges( self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ... @@ -708,7 +733,38 @@ def _delete_edges( ) elif len(self.__heterogeneous_edges_data_frame) > 1: raise ValueError("Edge Type must be specified for graph containing heterogeneous edges") - raise NotImplementedError # todo: Complete this function + elif isinstance(edge_t, str): + if ' ' in edge_t: + raise ValueError + if len(self.__heterogeneous_edges_data_frame) > 0: + self.__heterogeneous_edges_data_frame: pd.DataFrame = ( + self.__heterogeneous_edges_data_frame[ + self.__heterogeneous_edges_data_frame['r'] != edge_t + ].reset_index(drop=True) + ) + elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType): + if isinstance(edge_t, _typing.Tuple) and not ( + len(edge_t) == 3 and + isinstance(edge_t[0], str) and + isinstance(edge_t[1], str) and + isinstance(edge_t[2], str) and + ' ' not in edge_t[0] and ' ' not in edge_t[1] and ' ' not in edge_t[2] + ): + raise TypeError("Illegal canonical edge type") + __edge_t: _typing.Tuple[str, str, str] = ( + (edge_t.source_node_type, edge_t.relation_type, edge_t.target_node_type) + if isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType) else edge_t + ) + if len(self.__heterogeneous_edges_data_frame) > 0: + self.__heterogeneous_edges_data_frame: pd.DataFrame = ( + self.__heterogeneous_edges_data_frame[ + (self.__heterogeneous_edges_data_frame['s'] != edge_t) | + (self.__heterogeneous_edges_data_frame['r'] != edge_t) | + (self.__heterogeneous_edges_data_frame['t'] != edge_t) + ].reset_index(drop=True) + ) + else: + raise TypeError("Unsupported edge type") class _HomogeneousEdgesDataView(_abstract_views.HomogeneousEdgesDataView): @@ -759,6 +815,10 @@ def __init__(self, homogeneous_edges_container: HomogeneousEdgesContainer): def connections(self) -> torch.Tensor: return self._homogeneous_edges_container.connections + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + self._homogeneous_edges_container.connections = edges + @property def data(self) -> _HomogeneousEdgesDataView: return _HomogeneousEdgesDataView(self._homogeneous_edges_container) @@ -826,6 +886,10 @@ def __contains__(self, edge_type: _typing.Union[str, _typing.Tuple[str, str, str def connections(self) -> torch.Tensor: return self[...].connections + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + self[...].connections = edges + @property def data(self) -> _HomogeneousEdgesDataView: return self[...].data diff --git a/autogl/data/graph/_general_static_graph/_general_static_graph_dgl_implementation.py b/autogl/data/graph/_general_static_graph/_general_static_graph_dgl_implementation.py index b9bf413e..3c5b909a 100644 --- a/autogl/data/graph/_general_static_graph/_general_static_graph_dgl_implementation.py +++ b/autogl/data/graph/_general_static_graph/_general_static_graph_dgl_implementation.py @@ -426,6 +426,16 @@ def connections(self) -> torch.Tensor: self.__dgl_graph_holder.graph.edges(etype=self.__get_canonical_edge_type()) ) + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + self.__dgl_graph_holder.graph.remove_edges( + self.__dgl_graph_holder.graph.edges(etype=self.__get_canonical_edge_type(), form='eid'), + etype=self.__get_canonical_edge_type() + ) + self.__dgl_graph_holder.graph.add_edges( + edges[0], edges[1], etype=self.__get_canonical_edge_type() + ) + @property def data(self) -> _HomogeneousEdgesDataView: return _HomogeneousEdgesDataView(self.__dgl_graph_holder, self.__optional_edge_type) @@ -463,7 +473,12 @@ def __get_canonical_edge_type(self) -> _typing.Tuple[str, str, str]: @property def connections(self) -> torch.Tensor: - return _HomogeneousEdgesView(self.__dgl_graph_holder, ...).connections + return self[...].connections + + @connections.setter + def connections(self, edges: torch.LongTensor) -> None: + self[...].connections = edges + @property def data(self) -> _HomogeneousEdgesDataView: