Skip to content

Commit 23992ca

Browse files
fix cat empty features (#286)
1 parent 2527750 commit 23992ca

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

src/GNNGraphs/datastore.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ function Base.getproperty(ds::DataStore, s::Symbol)
106106
end
107107

108108
function Base.setproperty!(ds::DataStore, s::Symbol, x)
109-
@assert s!=:_n "cannot set _n directly"
110-
@assert s!=:_data "cannot set _data directly"
111-
if getn(ds) > 0
112-
@assert numobs(x)==getn(ds) "expected (numobs(x) == getn(ds)) but got $(numobs(x)) != $(getn(ds))"
109+
@assert s != :_n "cannot set _n directly"
110+
@assert s != :_data "cannot set _data directly"
111+
if getn(ds) >= 0
112+
@assert numobs(x) == getn(ds) "expected (numobs(x) == getn(ds)) but got $(numobs(x)) != $(getn(ds))"
113113
end
114114
return getdata(ds)[s] = x
115115
end
@@ -164,7 +164,7 @@ function MLUtils.getobs(ds::DataStore,
164164
i::AbstractVector{T}) where {T <: Union{Integer, Bool}}
165165
newdata = getobs(getdata(ds), i)
166166
n = getn(ds)
167-
if n > -1
167+
if n >= 0
168168
if length(ds) > 0
169169
n = numobs(newdata)
170170
else
@@ -180,14 +180,14 @@ end
180180

181181
function cat_features(ds1::DataStore, ds2::DataStore)
182182
n1, n2 = getn(ds1), getn(ds2)
183-
n1 = n1 > 0 ? n1 : 1
184-
n2 = n2 > 0 ? n2 : 1
183+
n1 = n1 >= 0 ? n1 : 1
184+
n2 = n2 >= 0 ? n2 : 1
185185
return DataStore(n1 + n2, cat_features(getdata(ds1), getdata(ds2)))
186186
end
187187

188188
function cat_features(dss::AbstractVector{DataStore}; kws...)
189189
ns = getn.(dss)
190-
ns = map(n -> n > 0 ? n : 1, ns)
190+
ns = map(n -> n >= 0 ? n : 1, ns)
191191
return DataStore(sum(ns), cat_features(getdata.(dss); kws...))
192192
end
193193

test/GNNGraphs/datastore.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@ end
3939
@test_throws KeyError ds.n
4040
end
4141

42+
@testset "cat empty" begin
43+
ds1 = DataStore(2, (:x => rand(2)))
44+
ds2 = DataStore(1, (:x => rand(1)))
45+
dsempty = DataStore(0, (:x => rand(0)))
46+
47+
ds = GNNGraphs.cat_features(ds1, ds2)
48+
@test getn(ds) == 3
49+
ds = GNNGraphs.cat_features(ds1, dsempty)
50+
@test getn(ds) == 2
51+
52+
# issue #280
53+
g = GNNGraph([1], [2])
54+
h = add_edges(g, Int[], Int[]) # adds no edges
55+
@test getn(g.edata) == 1
56+
@test getn(h.edata) == 1
57+
end
58+
59+
4260
@testset "gradient" begin
4361
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
4462

0 commit comments

Comments
 (0)