Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d7e211e

Browse files
committed
add same resolution
1 parent 9dfc134 commit d7e211e

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

example/SuperResolution/src/data.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,19 @@ function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::
4141

4242
return loader_train, loader_test
4343
end
44+
45+
function get_same_resolution(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Real=0.95, batchsize=100)
46+
data = gen_data(ts)
47+
48+
n_train, n_test = floor(Int, length(ts)*ratio), floor(Int, length(ts)*(1-ratio))
49+
50+
𝐱_train, 𝐲_train = data[:, 1:2:end, 1:2:end, 1:(n_train-1)], data[:, 1:2:end, 1:2:end, 2:n_train]
51+
𝐱_train, 𝐲_train = reshape(𝐱_train, 1, :, n_train-1), reshape(𝐲_train, 1, :, n_train-1)
52+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
53+
54+
𝐱_test, 𝐲_test = data[:, 1:2:end, 1:2:end, (end-n_test+1):(end-1)], data[:, 1:2:end, 1:2:end, (end-n_test+2):end]
55+
𝐱_test, 𝐲_test = reshape(𝐱_test, 1, :, n_test-1), reshape(𝐲_test, 1, :, n_test-1)
56+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
57+
58+
return loader_train, loader_test
59+
end

example/SuperResolution/src/models.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function train_mno(; cuda=true, η=1f-3, λ=1f-4, epochs=50)
6060
return m
6161
end
6262

63-
function train_gno(; channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50)
63+
function train_gno(; channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50, batchsize=64)
6464
# GPU config
6565
if cuda && CUDA.has_cuda()
6666
device = gpu
@@ -72,10 +72,10 @@ function train_gno(; channel=64, cuda=true, η=1f-3, λ=1f-4, epochs=50)
7272
end
7373

7474
@info "gen data... "
75-
@time loader_train, loader_test = get_dataloader()
75+
@time loader_train, loader_test = get_same_resolution(batchsize=batchsize)
7676

7777
# build model
78-
g = grid([12, 8])
78+
g = grid([96, 64])
7979
fg = FeaturedGraph(g)
8080

8181
m = Chain(

0 commit comments

Comments
 (0)