@@ -71,29 +71,6 @@ function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
71
71
return learner
72
72
end
73
73
74
- function batch_featured_graph (data, graph, batchsize)
75
- tot_len = size (data)[end ]
76
- bch_data = FeaturedGraph[]
77
- for i in 1 : batchsize: tot_len
78
- bch_rng = (i + batchsize >= tot_len) ? (i: tot_len) : (i: (i + batchsize - 1 ))
79
- fg = FeaturedGraph (graph, nf = data[:, :, bch_rng], pf = data[:, :, bch_rng])
80
- push! (bch_data, fg)
81
- end
82
-
83
- return bch_data
84
- end
85
-
86
- function batch_data (data, batchsize)
87
- tot_len = size (data)[end ]
88
- bch_data = Array{Float32, 3 }[]
89
- for i in 1 : batchsize: tot_len
90
- bch_rng = (i + batchsize >= tot_len) ? (i: tot_len) : (i: (i + batchsize - 1 ))
91
- push! (bch_data, data[:, :, bch_rng])
92
- end
93
-
94
- return bch_data
95
- end
96
-
97
74
function get_gno_dataloader (; ts:: AbstractRange = LinRange (100 , 11000 , 10000 ),
98
75
ratio:: Float64 = 0.95 , batchsize = 8 )
99
76
data = gen_data (ts)
@@ -111,17 +88,11 @@ function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
111
88
# flatten
112
89
𝐱, 𝐲 = reshape (𝐱, size (𝐱, 1 ), :, n), reshape (𝐲, 1 , :, n)
113
90
114
- data_train, data_test = splitobs (shuffleobs ((𝐱, 𝐲)), at = ratio)
115
-
116
- batched_train_X = batch_featured_graph (data_train[1 ], graph, batchsize)
117
- batched_test_X = batch_featured_graph (data_test[1 ], graph, batchsize)
118
- batched_train_y = batch_data (data_train[2 ], batchsize)
119
- batched_test_y = batch_data (data_test[2 ], batchsize)
91
+ fg = FeaturedGraph (graph, nf = 𝐱, pf = 𝐱)
92
+ data_train, data_test = splitobs (shuffleobs ((fg, 𝐲)), at = ratio)
120
93
121
- loader_train = DataLoader ((batched_train_X, batched_train_y), batchsize = - 1 ,
122
- shuffle = true )
123
- loader_test = DataLoader ((batched_test_X, batched_test_y), batchsize = - 1 ,
124
- shuffle = false )
94
+ loader_train = DataLoader (data_train, batchsize = batchsize, shuffle = true )
95
+ loader_test = DataLoader (data_test, batchsize = batchsize, shuffle = false )
125
96
126
97
return loader_train, loader_test
127
98
end
0 commit comments