Skip to content

Commit 9e0ad4a

Browse files
authored
Add Flux.gpu function for TemporalSnapshotsGNNGraph type (#362)
* Add `TemporalGraphConv` * Improve names * Add function to move TSG to gpu * Add `Flux.` * Add test for `gpu` movement * Revert "Improve names" This reverts commit e8532ab. * Revert "Add `TemporalGraphConv`" This reverts commit 47b78f1. * Add test for `gpu` movement * Add functor macro for `TemporalSnapshotsGNNGraph` * Add `TEST_GPU`
1 parent 767bd2a commit 9e0ad4a

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

src/GNNGraphs/temporalsnapshotsgnngraph.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ TemporalSnapshotsGNNGraph:
3636
```
3737
"""
3838
struct TemporalSnapshotsGNNGraph
39-
num_nodes::Vector{Int}
40-
num_edges::Vector{Int}
39+
num_nodes::AbstractVector{Int}
40+
num_edges::AbstractVector{Int}
4141
num_snapshots::Int
42-
snapshots::Vector{<:GNNGraph}
42+
snapshots::AbstractVector{<:GNNGraph}
4343
tgdata::DataStore
4444
end
4545

@@ -239,4 +239,6 @@ function print_feature_t(io::IO, feature)
239239
else
240240
print(io, "no")
241241
end
242-
end
242+
end
243+
244+
@functor TemporalSnapshotsGNNGraph

test/GNNGraphs/temporalsnapshotsgnngraph.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,17 @@ end
101101
tsg.tgdata.x=rand(4)
102102
@test sprint(show,tsg) == "TemporalSnapshotsGNNGraph(5) with x: 4-element data"
103103
end
104+
105+
if TEST_GPU
106+
@testset "gpu" begin
107+
snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
108+
tsg = TemporalSnapshotsGNNGraph(snapshots)
109+
tsg.tgdata.x = rand(5)
110+
tsg = Flux.gpu(tsg)
111+
@test tsg.snapshots[1].ndata.x isa CuArray
112+
@test tsg.snapshots[end].ndata.x isa CuArray
113+
@test tsg.tgdata.x isa CuArray
114+
@test tsg.num_nodes isa CuArray
115+
@test tsg.num_edges isa CuArray
116+
end
117+
end

0 commit comments

Comments
 (0)