Skip to content

Commit cc0a015

Browse files
committed
Fix functions
1 parent 92e7314 commit cc0a015

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

src/utils.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -105,43 +105,41 @@ function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
105105
return matrix[index, :]
106106
end
107107

108-
function _sort_row2(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) #scalarindexing
109-
sorted_matrix=sort(collect(eachrow(matrix)),by= x->x[end])
110-
reduce(hcat,sorted_matrix)'
111-
end
112-
113-
function _topk(feat::DataStore, k::Int; rev::Bool = true, sortby = nothing)
114-
matrices = values(feat)
108+
function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
115109
if sortby === nothing
116-
return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices)
110+
return sort(matrix, dims = 1; rev)[1:k, :]
117111
else
118-
return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices)
112+
return _sort_row(matrix; rev, sortby)[1:k, :]
119113
end
120114
end
121115

122-
function _topk2(matrices, k::Int; rev::Bool = true, sortby = nothing)
123-
if sortby === nothing
124-
return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices)
125-
else
126-
return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices)
127-
end
116+
function _sort_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
117+
return map(x -> _sort_matrix(x, k; rev, sortby), matrices)
128118
end
129119

130-
function _topk_tensor(feat::DataStore,numgra, k::Int; rev::Bool = true, sortby = nothing)
131-
matrices = values(feat)
132-
p=map(matrix -> reshape(matrix,size(matrix,1),size(matrix,2)÷numgra,numgra),matrices)
133-
v=map(x -> _topk2(collect(eachslice(x,dims=3)), k; rev,sortby), p)
134-
p=map(matrix -> reduce(hcat,matrix),v)
120+
function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
121+
sortby = nothing)
122+
tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs,
123+
number_graphs)
124+
sorted_matrix = _sort_batch(collect(eachslice(tensor_matrix, dims = 3)), k; rev, sortby)
125+
return reduce(hcat, sorted_matrix)
135126
end
136127

137-
138-
139-
140-
function topk_nodes(g::GNNGraph, k::Int; rev = true, sortby = nothing)
141-
return _topk(g.ndata, k; rev, sortby)
128+
function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
129+
sortby = nothing)
130+
if number_graphs==1
131+
return _sort_matrix(matrix, k; rev, sortby)
132+
else
133+
return _topk_batch(matrix, number_graphs, k; rev, sortby)
134+
end
142135
end
143136

144-
function topk_edges(g::GNNGraph, k::Int; rev = true, sortby = nothing)
145-
return _topk(g.edata, k; rev, sortby)
137+
function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
138+
matrix = getproperty(g.ndata, feat)
139+
return _topk(matrix, g.num_graphs, k; rev, sortby)
146140
end
147141

142+
function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
143+
matrix = getproperty(g.edata, feat)
144+
return _topk(matrix, g.num_graphs, k; rev, sortby)
145+
end

0 commit comments

Comments
 (0)