@@ -100,11 +100,13 @@ function broadcast_edges(g::GNNGraph, x)
100
100
return gather (x, gi)
101
101
end
102
102
103
+ # return a permuted matrix according to the sorting of the sortby column
103
104
function _sort_col (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 )
104
- index = sortperm (view (matrix, sortby, : ); rev)
105
- return matrix[ :, index]
105
+ index = sortperm (view (matrix, sortby, :); rev)
106
+ return matrix[:, index]
106
107
end
107
108
109
+ # sort and reshape matrix
108
110
function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
109
111
if sortby === nothing
110
112
return sort (matrix, dims = 2 ; rev)[:, 1 : k]
@@ -113,32 +115,45 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby =
113
115
end
114
116
end
115
117
116
- function _sort_batch (matrices:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
118
+ # sort the iterator of batch matrices
119
+ function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
117
120
return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
118
121
end
119
122
123
+ # sort and reshape batch matrix
120
124
function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
121
125
sortby = nothing )
122
126
tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
123
127
number_graphs)
124
- sorted_matrix = _sort_batch (collect ( eachslice (tensor_matrix, dims = 3 ) ), k; rev, sortby)
128
+ sorted_matrix = _sort_batch (eachslice (tensor_matrix, dims = 3 ), k; rev, sortby)
125
129
return reduce (hcat, sorted_matrix)
126
130
end
127
131
132
+ # topk for a feature matrix
128
133
function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
129
134
sortby = nothing )
130
- if number_graphs== 1
135
+ if number_graphs == 1
131
136
return _sort_matrix (matrix, k; rev, sortby)
132
137
else
133
138
return _topk_batch (matrix, number_graphs, k; rev, sortby)
134
139
end
135
140
end
136
141
142
+ """
143
+ topk_nodes(g, feat, k; rev = true, sortby = nothing)
144
+
145
+ Graph-wise top-k on node features `feat` according to the `sortby` feature index.
146
+ """
137
147
function topk_nodes (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
138
148
matrix = getproperty (g. ndata, feat)
139
149
return _topk (matrix, g. num_graphs, k; rev, sortby)
140
150
end
141
151
152
+ """
153
+ topk_edges(g, feat, k; rev = true, sortby = nothing)
154
+
155
+ Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
156
+ """
142
157
function topk_edges (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
143
158
matrix = getproperty (g. edata, feat)
144
159
return _topk (matrix, g. num_graphs, k; rev, sortby)
0 commit comments