Skip to content

Commit 60a2f05

Browse files
authored
Add TemporalSnapshotsGraph type (#221)
* Add TemporalSnapshotGraph * Add docs
1 parent d4986f1 commit 60a2f05

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

src/graph.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,104 @@ function Base.show(io::IO, ::MIME"text/plain", d::HeteroGraph)
198198
end
199199
end
200200

201+
struct TemporalSnapshotsGraph <: AbstractGraph
202+
num_nodes::Vector{Int}
203+
num_edges::Vector{Int}
204+
num_snapshots::Int
205+
snapshots::Vector{Graph}
206+
graph_data::Any
207+
end
208+
209+
210+
"""
211+
TemporalSnapshotsGraph(; kws...)
212+
213+
A type that represents a temporal snapshot graph as a sequence of [`Graph`](@ref)s and can store graph data.
214+
215+
Nodes are indexed in `1:num_nodes` and snapshots are indexed in `1:num_snapshots`.
216+
217+
# Keyword Arguments
218+
219+
- `num_nodes`: a vector containing the number of nodes at each snapshot.
220+
- `edge_index`: a tuple containing three vectors.
221+
The first vector contains the list of the source nodes of each edge, the second the target nodes at the third contains the snapshot at which each edge exists.
222+
Defaults to `(Int[], Int[], Int[])`.
223+
- `node_data`: node-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays.
224+
The arrays' last dimension size should be equal to the number of nodes.
225+
Default `nothing`.
226+
- `edge_data`: edge-related data. Can be `nothing`, a vector of named tuples of arrays or a dictionary of arrays.
227+
The arrays' last dimension size should be equal to the number of edges.
228+
Default `nothing`.
229+
- `graph_data`: graph-related data. Can be `nothing`, or a named tuple of arrays or a dictionary of arrays.
230+
231+
# Examples
232+
233+
```julia-repl
234+
julia> tg = MLDatasets.TemporalSnapshotsGraph(num_nodes = [10,10,10], edge_index= ([1,3,4,5,6,7,8],[2,6,7,1,2,10,9],[1,1,1,2,2,3,3]), node_data=[rand(3,10), rand(4,10), rand(2,10)])
235+
TemporalSnapshotsGraph:
236+
num_nodes => 3-element Vector{Int64}
237+
num_edges => 3-element Vector{Int64}
238+
num_snapsh => 3
239+
snapshots => 3-element Vector{Main.MLDatasets.Graph}
240+
graph_data => nothing
241+
242+
julia> tg.snapshots[1] # access the first snapshot
243+
Graph:
244+
num_nodes => 10
245+
num_edges => 3
246+
edge_index => ("3-element Vector{Int64}", "3-element Vector{Int64}")
247+
node_data => 3×10 Matrix{Float64}
248+
edge_data => nothing
249+
```
250+
"""
251+
function TemporalSnapshotsGraph(;
252+
num_nodes::Vector{Int},
253+
edge_index::Tuple{Vector{Int}, Vector{Int}, Vector{Int}} = (Int[], Int[], Int[]),
254+
node_data:: Union{Vector,Nothing} = nothing,
255+
edge_data:: Union{Vector,Nothing} = nothing,
256+
graph_data = nothing)
257+
258+
u, v, t = edge_index
259+
@assert length(u) == length(v) == length(t)
260+
num_snapshots = maximum(t)
261+
if !isnothing(node_data) && !isnothing(edge_data)
262+
@assert length(node_data) == length(edge_data) == num_snapshots
263+
end
264+
265+
snapshots = Vector{Graph}(undef, num_snapshots)
266+
num_edges = Vector{Int}(undef, num_snapshots)
267+
for i in 1:num_snapshots
268+
if !isnothing(node_data) && !isnothing(edge_data)
269+
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i], edge_data[i])
270+
elseif !isnothing(node_data)
271+
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), node_data[i],nothing)
272+
elseif !isnothing(edge_data)
273+
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, edge_data[i])
274+
else
275+
snapshot = Graph(num_nodes[i], sum(t.==i), (u[t.==i], v[t.==i]), nothing, nothing)
276+
end
277+
snapshots[i] = snapshot
278+
num_edges[i] = sum(t.==i)
279+
end
280+
return TemporalSnapshotsGraph(num_nodes, num_edges, num_snapshots, snapshots, graph_data)
281+
end
282+
283+
function Base.show(io::IO, d::TemporalSnapshotsGraph)
284+
print(io, "TemporalSnapshotsGraph($(d.num_nodes), $(d.num_edges), $(d.num_snapshots))")
285+
end
286+
287+
function Base.show(io::IO, ::MIME"text/plain", d::TemporalSnapshotsGraph)
288+
recur_io = IOContext(io, :compact => false)
289+
print(io, "TemporalSnapshotsGraph:")
290+
for f in fieldnames(TemporalSnapshotsGraph)
291+
if !startswith(string(f), "_")
292+
fstring = leftalign(string(f), 10)
293+
print(recur_io, "\n $fstring => ")
294+
print(recur_io, "$(_summary(getfield(d, f)))")
295+
end
296+
end
297+
end
298+
201299
# Transform an adjacency list to edge index.
202300
# If inneigs = true, assume neighbors from incoming edges.
203301
function adjlist2edgeindex(adj; inneigs = false)

0 commit comments

Comments
 (0)