@@ -99,3 +99,49 @@ function broadcast_edges(g::GNNGraph, x)
99
99
gi = graph_indicator (g, edges = true )
100
100
return gather (x, gi)
101
101
end
102
+
103
+ function _sort_row (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 )
104
+ index = sortperm (view (matrix, :, sortby); rev)
105
+ return matrix[index, :]
106
+ end
107
+
108
+ function _sort_row2 (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 ) # scalarindexing
109
+ sorted_matrix= sort (collect (eachrow (matrix)),by= x-> x[end ])
110
+ reduce (hcat,sorted_matrix)'
111
+ end
112
+
113
+ function _topk (feat:: DataStore , k:: Int ; rev:: Bool = true , sortby = nothing )
114
+ matrices = values (feat)
115
+ if sortby === nothing
116
+ return map (matrix -> sort (matrix, dims = 1 ; rev)[1 : k, :], matrices)
117
+ else
118
+ return map (matrix -> _sort_row (matrix; rev, sortby)[1 : k, :], matrices)
119
+ end
120
+ end
121
+
122
+ function _topk2 (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
123
+ if sortby === nothing
124
+ return map (matrix -> sort (matrix, dims = 1 ; rev)[1 : k, :], matrices)
125
+ else
126
+ return map (matrix -> _sort_row (matrix; rev, sortby)[1 : k, :], matrices)
127
+ end
128
+ end
129
+
130
+ function _topk_tensor (feat:: DataStore ,numgra, k:: Int ; rev:: Bool = true , sortby = nothing )
131
+ matrices = values (feat)
132
+ p= map (matrix -> reshape (matrix,size (matrix,1 ),size (matrix,2 )÷ numgra,numgra),matrices)
133
+ v= map (x -> _topk2 (collect (eachslice (x,dims= 3 )), k; rev,sortby), p)
134
+ p= map (matrix -> reduce (hcat,matrix),v)
135
+ end
136
+
137
+
138
+
139
+
140
+ function topk_nodes (g:: GNNGraph , k:: Int ; rev = true , sortby = nothing )
141
+ return _topk (g. ndata, k; rev, sortby)
142
+ end
143
+
144
+ function topk_edges (g:: GNNGraph , k:: Int ; rev = true , sortby = nothing )
145
+ return _topk (g. edata, k; rev, sortby)
146
+ end
147
+
0 commit comments