@@ -5,28 +5,28 @@ const NDict{T} = Dict{Symbol, T}
5
5
6
6
"""
7
7
GNNHeteroGraph(data; ndata, edata, gdata, num_nodes, graph_indicator, dir])
8
-
8
+
9
9
A type representing a heterogeneous graph structure.
10
10
it is similar [`GNNGraph`](@ref) but node and edges are of different types.
11
11
12
12
# Arguments
13
13
14
- - `data`: A dictionary or an iterable object that maps (source_type, edge_type, target_type)
14
+ - `data`: A dictionary or an iterable object that maps (source_type, edge_type, target_type)
15
15
triples to (source, target) index vectors.
16
16
- `num_nodes`: The number of nodes for each type. If not specified, inferred from `g`. Default `nothing`.
17
- - `graph_indicator`: For batched graphs, a dictionary of vectors containing the graph assignment of each node. Default `nothing`.
18
- - `ndata`: Node features. A dictionary of arrays or named tuple of arrays.
17
+ - `graph_indicator`: For batched graphs, a dictionary of vectors containing the graph assignment of each node. Default `nothing`.
18
+ - `ndata`: Node features. A dictionary of arrays or named tuple of arrays.
19
19
The size of the last dimension of each array must be given by `g.num_nodes`.
20
20
- `edata`: Edge features. A dictionary of arrays or named tuple of arrays.
21
21
The size of the last dimension of each array must be given by `g.num_edges`.
22
- - `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
22
+ - `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
23
23
24
24
25
25
!!! warning
26
26
`GNNHeteroGraph` is still experimental and not fully supported.
27
27
The interface could be subject to change in the future.
28
28
29
- # Examples
29
+ # Examples
30
30
31
31
```julia
32
32
julia> using Flux, GraphNeuralNetworks
@@ -43,7 +43,7 @@ julia> eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2);
43
43
44
44
julia> hg = GNNHeteroGraph(eindex; num_nodes)
45
45
GNNHeteroGraph:
46
- num_nodes: (:A => 10, :B => 20)
46
+ num_nodes: (:A => 10, :B => 20)
47
47
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
48
48
49
49
julia> hg.num_edges
@@ -57,7 +57,7 @@ julia> ndata = Dict(:A => (x = rand(2, num_nodes[:A]), y = rand(3, num_nodes[:A]
57
57
58
58
julia> hg = GNNHeteroGraph(eindex; num_nodes, ndata)
59
59
GNNHeteroGraph:
60
- num_nodes: (:A => 10, :B => 20)
60
+ num_nodes: (:A => 10, :B => 20)
61
61
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
62
62
ndata:
63
63
:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})
88
88
89
89
@functor GNNHeteroGraph
90
90
91
- function GNNHeteroGraph (data:: EDict ;
92
- num_nodes = nothing ,
93
- graph_indicator = nothing ,
94
- graph_type = :coo ,
95
- dir = :out ,
96
- ndata = NDict {NamedTuple} (),
97
- edata = EDict {NamedTuple} (),
98
- gdata = (;))
99
-
91
+ function GNNHeteroGraph (data:: EDict ;
92
+ num_nodes = nothing ,
93
+ graph_indicator = nothing ,
94
+ graph_type = :coo ,
95
+ dir = :out ,
96
+ ndata = NDict {NamedTuple} (),
97
+ edata = EDict {NamedTuple} (),
98
+ gdata = (;))
99
+
100
100
101
101
@assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
102
102
@assert dir ∈ [:in , :out ]
@@ -116,18 +116,18 @@ function GNNHeteroGraph(data::EDict;
116
116
elseif graph_type == :sparse
117
117
graph, num_nodes, num_edges = to_sparse (data; num_nodes, dir)
118
118
end
119
-
119
+
120
120
num_graphs = ! isnothing (graph_indicator) ? maximum ([maximum (gi) for gi in values (graph_indicator)]) : 1
121
-
121
+
122
122
ndata = normalize_heterographdata (ndata, default_name= :x , n= num_nodes)
123
123
edata = normalize_heterographdata (edata, default_name= :e , n= num_edges, duplicate_if_needed= true )
124
124
gdata = normalize_graphdata (gdata, default_name= :u , n= num_graphs)
125
-
126
- return GNNHeteroGraph (graph,
127
- num_nodes, num_edges, num_graphs,
128
- graph_indicator,
129
- ndata, edata, gdata,
130
- ntypes, etypes)
125
+
126
+ return GNNHeteroGraph (graph,
127
+ num_nodes, num_edges, num_graphs,
128
+ graph_indicator,
129
+ ndata, edata, gdata,
130
+ ntypes, etypes)
131
131
end
132
132
133
133
@@ -138,36 +138,33 @@ end
138
138
function Base. show (io:: IO , :: MIME"text/plain" , g:: GNNHeteroGraph )
139
139
if get (io, :compact , false )
140
140
print (io, " GNNHeteroGraph($(g. num_nodes) , $(g. num_edges) )" )
141
- else # if the following block is indented the printing is ruined
142
- print (io, " GNNHeteroGraph:
143
- num_nodes: $((g. num_nodes... ,))
144
- num_edges: $((g. num_edges... ,)) " )
145
- g. num_graphs > 1 && print (io, " \n num_graphs = $(g. num_graphs) " )
146
- if ! isempty (g. ndata)
147
- print (io, " \n ndata:" )
148
- for k in keys (g. ndata)
149
- print (io, " \n " , _str (k), " => $(shortsummary (g. ndata[k])) " )
150
- end
151
- end
152
- if ! isempty (g. edata)
153
- print (io, " \n edata:" )
154
- for k in keys (g. edata)
155
- print (io, " \n $k => $(shortsummary (g. edata[k])) " )
156
- end
157
- end
158
- if ! isempty (g. gdata)
159
- print (io, " \n gdata:" )
160
- print (io, " \n " )
161
- shortsummary (io, g. gdata)
162
- end # else
163
- end
141
+ else
142
+ print (io, " GNNHeteroGraph:\n num_nodes: $((g. num_nodes... ,)) \n num_edges: $((g. num_edges... ,)) " )
143
+ g. num_graphs > 1 && print (io, " \n num_graphs: $(g. num_graphs) " )
144
+ if ! isempty (g. ndata)
145
+ print (io, " \n ndata:" )
146
+ for k in keys (g. ndata)
147
+ print (io, " \n\t " , _str (k), " => $(shortsummary (g. ndata[k])) " )
148
+ end
149
+ end
150
+ if ! isempty (g. edata)
151
+ print (io, " \n edata:" )
152
+ for k in keys (g. edata)
153
+ print (io, " \n\t $k => $(shortsummary (g. edata[k])) " )
154
+ end
155
+ end
156
+ if ! isempty (g. gdata)
157
+ print (io, " \n gdata:\n\t " )
158
+ shortsummary (io, g. gdata)
159
+ end
160
+ end
164
161
end
165
162
166
163
GNNHeteroGraph (data; kws... ) = GNNHeteroGraph (Dict (data); kws... )
167
164
168
165
_str (s:: Symbol ) = " :$s "
169
166
_str (s) = " $s "
170
167
171
- MLUtils. numobs (g:: GNNHeteroGraph ) = g. num_graphs
168
+ MLUtils. numobs (g:: GNNHeteroGraph ) = g. num_graphs
172
169
# MLUtils.getobs(g::GNNHeteroGraph, i) = getgraph(g, i)
173
170
0 commit comments