Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions autogl/data/graph/_general_static_graph/_abstract_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class HomogeneousEdgesView:
def connections(self) -> torch.LongTensor:
raise NotImplementedError

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
raise NotImplementedError

@property
def data(self) -> HomogeneousEdgesDataView:
raise NotImplementedError
Expand All @@ -91,6 +95,10 @@ class HeterogeneousEdgesView(_typing.Collection[_canonical_edge_type.CanonicalEd
def connections(self) -> torch.LongTensor:
raise NotImplementedError

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
raise NotImplementedError

@property
def data(self) -> HomogeneousEdgesDataView:
raise NotImplementedError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ class HomogeneousEdgesContainer:
def connections(self) -> torch.Tensor:
raise NotImplementedError

@connections.setter
def connections(self, edges: torch.Tensor) -> None:
raise NotImplementedError

@property
def data_keys(self) -> _typing.Iterable[str]:
raise NotImplementedError
Expand Down Expand Up @@ -438,6 +442,10 @@ def __init__(
def connections(self) -> torch.Tensor:
return self.__connections

@connections.setter
def connections(self, edges: torch.Tensor) -> None:
self.__connections = edges

@property
def data_keys(self) -> _typing.Iterable[str]:
return self.__data.keys()
Expand Down Expand Up @@ -547,27 +555,41 @@ def _get_edges(
self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ...
) -> HomogeneousEdgesContainer:
if edge_t in (Ellipsis, None):
if len(self.__heterogeneous_edges_data_frame) == 1:
if len(self.__heterogeneous_edges_data_frame) == 0:
raise ValueError("The graph contains no edges")
elif len(self.__heterogeneous_edges_data_frame) == 1:
return self.__heterogeneous_edges_data_frame.iloc[0]['edges']
else:
raise RuntimeError # Undetermined
raise ValueError(
"Unable to automatically determine edge type "
"since the graph contains multiple edge types"
)
elif isinstance(edge_t, str):
if ' ' in edge_t:
raise ValueError
if len(
self.__heterogeneous_edges_data_frame.loc[
self.__heterogeneous_edges_data_frame['r'] == edge_t
]
) != 1:
raise ValueError # todo: Unable to determine
else:
) == 0:
raise ValueError(f"The graph has NOT edge with relation type as {edge_t}")
elif len(
self.__heterogeneous_edges_data_frame.loc[
self.__heterogeneous_edges_data_frame['r'] == edge_t
]
) == 1:
temp: HomogeneousEdgesContainer = self.__heterogeneous_edges_data_frame.loc[
self.__heterogeneous_edges_data_frame['r'] == edge_t, 'edges'
]
if not isinstance(temp, HomogeneousEdgesContainer):
raise RuntimeError
else:
return temp
else:
raise ValueError(
f"Unable to determine canonical edge type by relation type \"{edge_t}\", "
f"since the graph contains multiple edge types with relation type as \"{edge_t}\""
)
elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType):
if isinstance(edge_t, _typing.Tuple) and not (
len(edge_t) == 3 and
Expand Down Expand Up @@ -625,7 +647,10 @@ def _set_edges(
else HomogeneousEdgesContainerImplementation(edges)
)
else:
raise RuntimeError # todo: Unable to determine error
raise ValueError(
"Unable to set edges for heterogeneous graph consist of multiple edge types "
"with automatically determined edge type"
)
elif isinstance(edge_t, str):
if ' ' in edge_t:
raise ValueError
Expand Down Expand Up @@ -694,9 +719,9 @@ def _set_edges(
else HomogeneousEdgesContainerImplementation(edges)
)
else:
raise RuntimeError # todo: Unable to determine error
raise RuntimeError
else:
raise RuntimeError
raise TypeError("Unsupported edge type")

def _delete_edges(
self, edge_t: _typing.Union[None, str, _typing.Tuple[str, str, str], _canonical_edge_type.CanonicalEdgeType] = ...
Expand All @@ -708,7 +733,38 @@ def _delete_edges(
)
elif len(self.__heterogeneous_edges_data_frame) > 1:
raise ValueError("Edge Type must be specified for graph containing heterogeneous edges")
raise NotImplementedError # todo: Complete this function
elif isinstance(edge_t, str):
if ' ' in edge_t:
raise ValueError
if len(self.__heterogeneous_edges_data_frame) > 0:
self.__heterogeneous_edges_data_frame: pd.DataFrame = (
self.__heterogeneous_edges_data_frame[
self.__heterogeneous_edges_data_frame['r'] != edge_t
].reset_index(drop=True)
)
elif isinstance(edge_t, _typing.Tuple) or isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType):
if isinstance(edge_t, _typing.Tuple) and not (
len(edge_t) == 3 and
isinstance(edge_t[0], str) and
isinstance(edge_t[1], str) and
isinstance(edge_t[2], str) and
' ' not in edge_t[0] and ' ' not in edge_t[1] and ' ' not in edge_t[2]
):
raise TypeError("Illegal canonical edge type")
__edge_t: _typing.Tuple[str, str, str] = (
(edge_t.source_node_type, edge_t.relation_type, edge_t.target_node_type)
if isinstance(edge_t, _canonical_edge_type.CanonicalEdgeType) else edge_t
)
if len(self.__heterogeneous_edges_data_frame) > 0:
self.__heterogeneous_edges_data_frame: pd.DataFrame = (
self.__heterogeneous_edges_data_frame[
(self.__heterogeneous_edges_data_frame['s'] != edge_t) |
(self.__heterogeneous_edges_data_frame['r'] != edge_t) |
(self.__heterogeneous_edges_data_frame['t'] != edge_t)
].reset_index(drop=True)
)
else:
raise TypeError("Unsupported edge type")


class _HomogeneousEdgesDataView(_abstract_views.HomogeneousEdgesDataView):
Expand Down Expand Up @@ -759,6 +815,10 @@ def __init__(self, homogeneous_edges_container: HomogeneousEdgesContainer):
def connections(self) -> torch.Tensor:
return self._homogeneous_edges_container.connections

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
self._homogeneous_edges_container.connections = edges

@property
def data(self) -> _HomogeneousEdgesDataView:
return _HomogeneousEdgesDataView(self._homogeneous_edges_container)
Expand Down Expand Up @@ -826,6 +886,10 @@ def __contains__(self, edge_type: _typing.Union[str, _typing.Tuple[str, str, str
def connections(self) -> torch.Tensor:
return self[...].connections

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
self[...].connections = edges

@property
def data(self) -> _HomogeneousEdgesDataView:
return self[...].data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,16 @@ def connections(self) -> torch.Tensor:
self.__dgl_graph_holder.graph.edges(etype=self.__get_canonical_edge_type())
)

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
self.__dgl_graph_holder.graph.remove_edges(
self.__dgl_graph_holder.graph.edges(etype=self.__get_canonical_edge_type(), form='eid'),
etype=self.__get_canonical_edge_type()
)
self.__dgl_graph_holder.graph.add_edges(
edges[0], edges[1], etype=self.__get_canonical_edge_type()
)

@property
def data(self) -> _HomogeneousEdgesDataView:
return _HomogeneousEdgesDataView(self.__dgl_graph_holder, self.__optional_edge_type)
Expand Down Expand Up @@ -463,7 +473,12 @@ def __get_canonical_edge_type(self) -> _typing.Tuple[str, str, str]:

@property
def connections(self) -> torch.Tensor:
return _HomogeneousEdgesView(self.__dgl_graph_holder, ...).connections
return self[...].connections

@connections.setter
def connections(self, edges: torch.LongTensor) -> None:
self[...].connections = edges


@property
def data(self) -> _HomogeneousEdgesDataView:
Expand Down