Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c51ef55

Browse files
committed
FeaturedGraph support DataLoader
1 parent 6266a94 commit c51ef55

File tree

2 files changed

+5
-34
lines changed

2 files changed

+5
-34
lines changed

example/FlowOverCircle/src/FlowOverCircle.jl

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -71,29 +71,6 @@ function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
7171
return learner
7272
end
7373

74-
function batch_featured_graph(data, graph, batchsize)
75-
tot_len = size(data)[end]
76-
bch_data = FeaturedGraph[]
77-
for i in 1:batchsize:tot_len
78-
bch_rng = (i + batchsize >= tot_len) ? (i:tot_len) : (i:(i + batchsize - 1))
79-
fg = FeaturedGraph(graph, nf = data[:, :, bch_rng], pf = data[:, :, bch_rng])
80-
push!(bch_data, fg)
81-
end
82-
83-
return bch_data
84-
end
85-
86-
function batch_data(data, batchsize)
87-
tot_len = size(data)[end]
88-
bch_data = Array{Float32, 3}[]
89-
for i in 1:batchsize:tot_len
90-
bch_rng = (i + batchsize >= tot_len) ? (i:tot_len) : (i:(i + batchsize - 1))
91-
push!(bch_data, data[:, :, bch_rng])
92-
end
93-
94-
return bch_data
95-
end
96-
9774
function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
9875
ratio::Float64 = 0.95, batchsize = 8)
9976
data = gen_data(ts)
@@ -111,17 +88,11 @@ function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
11188
# flatten
11289
𝐱, 𝐲 = reshape(𝐱, size(𝐱, 1), :, n), reshape(𝐲, 1, :, n)
11390

114-
data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
115-
116-
batched_train_X = batch_featured_graph(data_train[1], graph, batchsize)
117-
batched_test_X = batch_featured_graph(data_test[1], graph, batchsize)
118-
batched_train_y = batch_data(data_train[2], batchsize)
119-
batched_test_y = batch_data(data_test[2], batchsize)
91+
fg = FeaturedGraph(graph, nf = 𝐱, pf = 𝐱)
92+
data_train, data_test = splitobs(shuffleobs((fg, 𝐲)), at = ratio)
12093

121-
loader_train = DataLoader((batched_train_X, batched_train_y), batchsize = -1,
122-
shuffle = true)
123-
loader_test = DataLoader((batched_test_X, batched_test_y), batchsize = -1,
124-
shuffle = false)
94+
loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
95+
loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)
12596

12697
return loader_train, loader_test
12798
end

src/graph_kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function (l::GraphKernel)(fg::AbstractFeaturedGraph)
6666
end
6767

6868
function Base.show(io::IO, l::GraphKernel)
69-
channel, _ = size(l.linear)
69+
channel = size(l.linear, 1)
7070
print(io, "GraphKernel(", l.κ, ", channel=", channel)
7171
l.σ == identity || print(io, ", ", l.σ)
7272
print(io, ")")

0 commit comments

Comments
 (0)