Skip to content

Commit 86f234f

Browse files
implement StatsBase.nobs instead of LearnBase.nobs (#62)
* implement StatsBase.nobs * fix test * fix test
1 parent 0567199 commit 86f234f

File tree

5 files changed

+8
-4
lines changed

5 files changed

+8
-4
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1919
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2020
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2121
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
22+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2223
TestEnv = "1e6cf692-eddd-4d53-88a5-2d735e33781b"
2324

2425
[compat]
@@ -33,8 +34,9 @@ LearnBase = "0.4, 0.5"
3334
MacroTools = "0.5"
3435
NNlib = "0.7"
3536
NNlibCUDA = "0.1"
36-
julia = "1.6"
37+
StatsBase = "0.32, 0.33"
3738
TestEnv = "1"
39+
julia = "1.6"
3840

3941
[extras]
4042
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Flux
1010
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
1111
using MacroTools: @forward
1212
import LearnBase
13+
import StatsBase
1314
using LearnBase: getobs
1415
using NNlib, NNlibCUDA
1516
using NNlib: scatter, gather

src/gnngraph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ Equivalent to [`SparseArrays.blockdiag`](@ref).
499499
"""
500500
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
501501

502-
### LearnBase compatibility
503-
LearnBase.nobs(g::GNNGraph) = g.num_graphs
502+
### StatsBase/LearnBase compatibility
503+
StatsBase.nobs(g::GNNGraph) = g.num_graphs
504504
LearnBase.getobs(g::GNNGraph, i) = getgraph(g, i)
505505

506506
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683

test/gnngraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@
272272

273273
@test LearnBase.getobs(g, 3) == getgraph(g, 3)
274274
@test LearnBase.getobs(g, 3:5) == getgraph(g, 3:5)
275-
@test LearnBase.nobs(g) == g.num_graphs
275+
@test StatsBase.nobs(g) == g.num_graphs
276276

277277
d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
278278
@test first(d) == getgraph(g, 1:2)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Flux: gpu, @functor
55
using LinearAlgebra, Statistics, Random
66
using NNlib
77
using LearnBase
8+
import StatsBase
89
using Graphs
910
using Zygote
1011
using Test

0 commit comments

Comments
 (0)