Skip to content

Commit eae9575

Browse files
authored
Improve hetero show function (#237)
* Solve indentation problem * Add test show
1 parent 829be8a commit eae9575

File tree

2 files changed

+66
-49
lines changed

2 files changed

+66
-49
lines changed

src/GNNGraphs/gnnheterograph.jl

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,28 @@ const NDict{T} = Dict{Symbol, T}
55

66
"""
77
GNNHeteroGraph(data; ndata, edata, gdata, num_nodes, graph_indicator, dir])
8-
8+
99
A type representing a heterogeneous graph structure.
1010
it is similar [`GNNGraph`](@ref) but node and edges are of different types.
1111
1212
# Arguments
1313
14-
- `data`: A dictionary or an iterable object that maps (source_type, edge_type, target_type)
14+
- `data`: A dictionary or an iterable object that maps (source_type, edge_type, target_type)
1515
triples to (source, target) index vectors.
1616
- `num_nodes`: The number of nodes for each type. If not specified, inferred from `g`. Default `nothing`.
17-
- `graph_indicator`: For batched graphs, a dictionary of vectors containing the graph assignment of each node. Default `nothing`.
18-
- `ndata`: Node features. A dictionary of arrays or named tuple of arrays.
17+
- `graph_indicator`: For batched graphs, a dictionary of vectors containing the graph assignment of each node. Default `nothing`.
18+
- `ndata`: Node features. A dictionary of arrays or named tuple of arrays.
1919
The size of the last dimension of each array must be given by `g.num_nodes`.
2020
- `edata`: Edge features. A dictionary of arrays or named tuple of arrays.
2121
The size of the last dimension of each array must be given by `g.num_edges`.
22-
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
22+
- `gdata`: Graph features. An array or named tuple of arrays whose last dimension has size `num_graphs`.
2323
2424
2525
!!! warning
2626
`GNNHeteroGraph` is still experimental and not fully supported.
2727
The interface could be subject to change in the future.
2828
29-
# Examples
29+
# Examples
3030
3131
```julia
3232
julia> using Flux, GraphNeuralNetworks
@@ -43,7 +43,7 @@ julia> eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2);
4343
4444
julia> hg = GNNHeteroGraph(eindex; num_nodes)
4545
GNNHeteroGraph:
46-
num_nodes: (:A => 10, :B => 20)
46+
num_nodes: (:A => 10, :B => 20)
4747
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
4848
4949
julia> hg.num_edges
@@ -57,7 +57,7 @@ julia> ndata = Dict(:A => (x = rand(2, num_nodes[:A]), y = rand(3, num_nodes[:A]
5757
5858
julia> hg = GNNHeteroGraph(eindex; num_nodes, ndata)
5959
GNNHeteroGraph:
60-
num_nodes: (:A => 10, :B => 20)
60+
num_nodes: (:A => 10, :B => 20)
6161
num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
6262
ndata:
6363
:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})
@@ -88,15 +88,15 @@ end
8888

8989
@functor GNNHeteroGraph
9090

91-
function GNNHeteroGraph(data::EDict;
92-
num_nodes = nothing,
93-
graph_indicator = nothing,
94-
graph_type = :coo,
95-
dir = :out,
96-
ndata = NDict{NamedTuple}(),
97-
edata = EDict{NamedTuple}(),
98-
gdata = (;))
99-
91+
function GNNHeteroGraph(data::EDict;
92+
num_nodes = nothing,
93+
graph_indicator = nothing,
94+
graph_type = :coo,
95+
dir = :out,
96+
ndata = NDict{NamedTuple}(),
97+
edata = EDict{NamedTuple}(),
98+
gdata = (;))
99+
100100

101101
@assert graph_type [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
102102
@assert dir [:in, :out]
@@ -116,18 +116,18 @@ function GNNHeteroGraph(data::EDict;
116116
elseif graph_type == :sparse
117117
graph, num_nodes, num_edges = to_sparse(data; num_nodes, dir)
118118
end
119-
119+
120120
num_graphs = !isnothing(graph_indicator) ? maximum([maximum(gi) for gi in values(graph_indicator)]) : 1
121-
121+
122122
ndata = normalize_heterographdata(ndata, default_name=:x, n=num_nodes)
123123
edata = normalize_heterographdata(edata, default_name=:e, n=num_edges, duplicate_if_needed=true)
124124
gdata = normalize_graphdata(gdata, default_name=:u, n=num_graphs)
125-
126-
return GNNHeteroGraph(graph,
127-
num_nodes, num_edges, num_graphs,
128-
graph_indicator,
129-
ndata, edata, gdata,
130-
ntypes, etypes)
125+
126+
return GNNHeteroGraph(graph,
127+
num_nodes, num_edges, num_graphs,
128+
graph_indicator,
129+
ndata, edata, gdata,
130+
ntypes, etypes)
131131
end
132132

133133

@@ -138,36 +138,33 @@ end
138138
function Base.show(io::IO, ::MIME"text/plain", g::GNNHeteroGraph)
139139
if get(io, :compact, false)
140140
print(io, "GNNHeteroGraph($(g.num_nodes), $(g.num_edges))")
141-
else # if the following block is indented the printing is ruined
142-
print(io, "GNNHeteroGraph:
143-
num_nodes: $((g.num_nodes...,))
144-
num_edges: $((g.num_edges...,))")
145-
g.num_graphs > 1 && print(io, "\n num_graphs = $(g.num_graphs)")
146-
if !isempty(g.ndata)
147-
print(io, "\n ndata:")
148-
for k in keys(g.ndata)
149-
print(io, "\n ", _str(k), " => $(shortsummary(g.ndata[k]))")
150-
end
151-
end
152-
if !isempty(g.edata)
153-
print(io, "\n edata:")
154-
for k in keys(g.edata)
155-
print(io, "\n $k => $(shortsummary(g.edata[k]))")
156-
end
157-
end
158-
if !isempty(g.gdata)
159-
print(io, "\n gdata:")
160-
print(io, "\n ")
161-
shortsummary(io, g.gdata)
162-
end #else
163-
end
141+
else
142+
print(io, "GNNHeteroGraph:\n num_nodes: $((g.num_nodes...,))\n num_edges: $((g.num_edges...,))")
143+
g.num_graphs > 1 && print(io, "\n num_graphs: $(g.num_graphs)")
144+
if !isempty(g.ndata)
145+
print(io, "\n ndata:")
146+
for k in keys(g.ndata)
147+
print(io, "\n\t", _str(k), " => $(shortsummary(g.ndata[k]))")
148+
end
149+
end
150+
if !isempty(g.edata)
151+
print(io, "\n edata:")
152+
for k in keys(g.edata)
153+
print(io, "\n\t$k => $(shortsummary(g.edata[k]))")
154+
end
155+
end
156+
if !isempty(g.gdata)
157+
print(io, "\n gdata:\n\t")
158+
shortsummary(io, g.gdata)
159+
end
160+
end
164161
end
165162

166163
GNNHeteroGraph(data; kws...) = GNNHeteroGraph(Dict(data); kws...)
167164

168165
_str(s::Symbol) = ":$s"
169166
_str(s) = "$s"
170167

171-
MLUtils.numobs(g::GNNHeteroGraph) = g.num_graphs
168+
MLUtils.numobs(g::GNNHeteroGraph) = g.num_graphs
172169
# MLUtils.getobs(g::GNNHeteroGraph, i) = getgraph(g, i)
173170

test/GNNGraphs/gnnheterograph.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,24 @@ using Test
7575
@test hg.num_nodes == Dict(:A => 10, :B => 20)
7676
@test hg.num_edges == Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)
7777
end
78+
79+
@testset "show" begin
80+
num_nodes = Dict(:A => 10, :B => 20);
81+
edges1 = rand(1:num_nodes[:A], 20), rand(1:num_nodes[:B], 20)
82+
edges2 = rand(1:num_nodes[:B], 30), rand(1:num_nodes[:A], 30)
83+
eindex = ((:A, :rel1, :B) => edges1, (:B, :rel2, :A) => edges2)
84+
ndata = Dict(:A => (x = rand(2, num_nodes[:A]), y = rand(3, num_nodes[:A])),:B => rand(10, num_nodes[:B]))
85+
edata= Dict((:A, :rel1, :B) => (x = rand(2, 20), y = rand(3, 20)),(:B, :rel2, :A) => rand(10, 30))
86+
hg1 = GraphNeuralNetworks.GNNHeteroGraph(eindex; num_nodes)
87+
hg2 = GraphNeuralNetworks.GNNHeteroGraph(eindex; num_nodes, ndata,edata)
88+
hg3 = GraphNeuralNetworks.GNNHeteroGraph(eindex; num_nodes, ndata)
89+
@test sprint(show, hg1) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))"
90+
@test sprint(show, hg2) == sprint(show, hg1)
91+
@test sprint(show, MIME("text/plain"), hg1; context=:compact => true) == "GNNHeteroGraph(Dict(:A => 10, :B => 20), Dict((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30))"
92+
@test sprint(show, MIME("text/plain"), hg2; context=:compact => true) == sprint(show, MIME("text/plain"), hg1;context=:compact => true)
93+
@test sprint(show, MIME("text/plain"), hg1; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)"
94+
@test sprint(show, MIME("text/plain"), hg2; context=:compact => false) == "GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}\n edata:\n\t(:A, :rel1, :B) => (x = 2×20 Matrix{Float64}, y = 3×20 Matrix{Float64})\n\t(:B, :rel2, :A) => e = 10×30 Matrix{Float64}"
95+
@test sprint(show, MIME("text/plain"), hg3; context=:compact => false) =="GNNHeteroGraph:\n num_nodes: (:A => 10, :B => 20)\n num_edges: ((:A, :rel1, :B) => 20, (:B, :rel2, :A) => 30)\n ndata:\n\t:A => (x = 2×10 Matrix{Float64}, y = 3×10 Matrix{Float64})\n\t:B => x = 10×20 Matrix{Float64}"
96+
@test sprint(show, MIME("text/plain"), hg2; context=:compact => false) != sprint(show, MIME("text/plain"), hg3; context=:compact => false)
97+
end
7898
end

0 commit comments

Comments
 (0)