Skip to content

Commit 61ae0ef

Browse files
don't rely on LearnBase.getobs (#90)
1 parent 1aade04 commit 61ae0ef

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import NearestNeighbors
1111
import NNlib
1212
import LearnBase
1313
import StatsBase
14-
using LearnBase: getobs
1514
import KrylovKit
1615
using ChainRulesCore
1716
using LinearAlgebra, Random

src/GNNGraphs/utils.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,63 @@ end
128128
@non_differentiable edge_encoding(x...)
129129
@non_differentiable edge_decoding(x...)
130130

131+
132+
133+
####################################
134+
# FROM MLBASE.jl
135+
# https://github.com/JuliaML/MLBase.jl/pull/1/files
136+
# remove when package is registered
137+
##############################################
138+
139+
numobs(A::AbstractArray{<:Any, N}) where {N} = size(A, N)
140+
141+
# 0-dim arrays
142+
numobs(A::AbstractArray{<:Any, 0}) = 1
143+
144+
function getobs(A::AbstractArray{<:Any, N}, idx) where N
145+
I = ntuple(_ -> :, N-1)
146+
return A[I..., idx]
147+
end
148+
149+
getobs(A::AbstractArray{<:Any, 0}, idx) = A[idx]
150+
151+
function getobs!(buffer::AbstractArray, A::AbstractArray{<:Any, N}, idx) where N
152+
I = ntuple(_ -> :, N-1)
153+
buffer .= A[I..., idx]
154+
return buffer
155+
end
156+
157+
# --------------------------------------------------------------------
158+
# Tuples and NamedTuples
159+
160+
_check_numobs_error() =
161+
throw(DimensionMismatch("All data containers must have the same number of observations."))
162+
163+
function _check_numobs(tup::Union{Tuple, NamedTuple})
164+
length(tup) == 0 && return
165+
n1 = numobs(tup[1])
166+
for i=2:length(tup)
167+
numobs(tup[i]) != n1 && _check_numobs_error()
168+
end
169+
end
170+
171+
function numobs(tup::Union{Tuple, NamedTuple})::Int
172+
_check_numobs(tup)
173+
return length(tup) == 0 ? 0 : numobs(tup[1])
174+
end
175+
176+
function getobs(tup::Union{Tuple, NamedTuple}, indices)
177+
_check_numobs(tup)
178+
return map(x -> getobs(x, indices), tup)
179+
end
180+
181+
function getobs!(buffers::Union{Tuple, NamedTuple},
182+
tup::Union{Tuple, NamedTuple},
183+
indices)
184+
_check_numobs(tup)
185+
186+
return map(buffers, tup) do buffer, x
187+
getobs!(buffer, x, indices)
188+
end
189+
end
190+
#######################################################

0 commit comments

Comments
 (0)