Skip to content

Commit f72cacb

Browse files
committed
Add functions
1 parent c5bd656 commit f72cacb

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

src/utils.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,49 @@ function broadcast_edges(g::GNNGraph, x)
9999
gi = graph_indicator(g, edges = true)
100100
return gather(x, gi)
101101
end
102+
103+
function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
104+
index = sortperm(view(matrix, :, sortby); rev)
105+
return matrix[index, :]
106+
end
107+
108+
function _sort_row2(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) #scalarindexing
109+
sorted_matrix=sort(collect(eachrow(matrix)),by= x->x[end])
110+
reduce(hcat,sorted_matrix)'
111+
end
112+
113+
function _topk(feat::DataStore, k::Int; rev::Bool = true, sortby = nothing)
114+
matrices = values(feat)
115+
if sortby === nothing
116+
return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices)
117+
else
118+
return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices)
119+
end
120+
end
121+
122+
function _topk2(matrices, k::Int; rev::Bool = true, sortby = nothing)
123+
if sortby === nothing
124+
return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices)
125+
else
126+
return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices)
127+
end
128+
end
129+
130+
function _topk_tensor(feat::DataStore,numgra, k::Int; rev::Bool = true, sortby = nothing)
131+
matrices = values(feat)
132+
p=map(matrix -> reshape(matrix,size(matrix,1),size(matrix,2)÷numgra,numgra),matrices)
133+
v=map(x -> _topk2(collect(eachslice(x,dims=3)), k; rev,sortby), p)
134+
p=map(matrix -> reduce(hcat,matrix),v)
135+
end
136+
137+
138+
139+
140+
function topk_nodes(g::GNNGraph, k::Int; rev = true, sortby = nothing)
141+
return _topk(g.ndata, k; rev, sortby)
142+
end
143+
144+
function topk_edges(g::GNNGraph, k::Int; rev = true, sortby = nothing)
145+
return _topk(g.edata, k; rev, sortby)
146+
end
147+

0 commit comments

Comments
 (0)