128
128
@non_differentiable edge_encoding (x... )
129
129
@non_differentiable edge_decoding (x... )
130
130
131
+
132
+
133
+ # ###################################
134
+ # FROM MLBASE.jl
135
+ # https://github.com/JuliaML/MLBase.jl/pull/1/files
136
+ # remove when package is registered
137
+ # #############################################
138
+
139
+ numobs (A:: AbstractArray{<:Any, N} ) where {N} = size (A, N)
140
+
141
+ # 0-dim arrays
142
+ numobs (A:: AbstractArray{<:Any, 0} ) = 1
143
+
144
+ function getobs (A:: AbstractArray{<:Any, N} , idx) where N
145
+ I = ntuple (_ -> :, N- 1 )
146
+ return A[I... , idx]
147
+ end
148
+
149
+ getobs (A:: AbstractArray{<:Any, 0} , idx) = A[idx]
150
+
151
+ function getobs! (buffer:: AbstractArray , A:: AbstractArray{<:Any, N} , idx) where N
152
+ I = ntuple (_ -> :, N- 1 )
153
+ buffer .= A[I... , idx]
154
+ return buffer
155
+ end
156
+
157
+ # --------------------------------------------------------------------
158
+ # Tuples and NamedTuples
159
+
160
+ _check_numobs_error () =
161
+ throw (DimensionMismatch (" All data containers must have the same number of observations." ))
162
+
163
+ function _check_numobs (tup:: Union{Tuple, NamedTuple} )
164
+ length (tup) == 0 && return
165
+ n1 = numobs (tup[1 ])
166
+ for i= 2 : length (tup)
167
+ numobs (tup[i]) != n1 && _check_numobs_error ()
168
+ end
169
+ end
170
+
171
+ function numobs (tup:: Union{Tuple, NamedTuple} ):: Int
172
+ _check_numobs (tup)
173
+ return length (tup) == 0 ? 0 : numobs (tup[1 ])
174
+ end
175
+
176
+ function getobs (tup:: Union{Tuple, NamedTuple} , indices)
177
+ _check_numobs (tup)
178
+ return map (x -> getobs (x, indices), tup)
179
+ end
180
+
181
+ function getobs! (buffers:: Union{Tuple, NamedTuple} ,
182
+ tup:: Union{Tuple, NamedTuple} ,
183
+ indices)
184
+ _check_numobs (tup)
185
+
186
+ return map (buffers, tup) do buffer, x
187
+ getobs! (buffer, x, indices)
188
+ end
189
+ end
190
+ # ######################################################
0 commit comments