@@ -105,43 +105,41 @@ function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
105
105
return matrix[index, :]
106
106
end
107
107
108
- function _sort_row2 (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 ) # scalarindexing
109
- sorted_matrix= sort (collect (eachrow (matrix)),by= x-> x[end ])
110
- reduce (hcat,sorted_matrix)'
111
- end
112
-
113
- function _topk (feat:: DataStore , k:: Int ; rev:: Bool = true , sortby = nothing )
114
- matrices = values (feat)
108
+ function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
115
109
if sortby === nothing
116
- return map (matrix -> sort (matrix, dims = 1 ; rev)[1 : k, :], matrices)
110
+ return sort (matrix, dims = 1 ; rev)[1 : k, :]
117
111
else
118
- return map (matrix -> _sort_row (matrix; rev, sortby)[1 : k, :], matrices)
112
+ return _sort_row (matrix; rev, sortby)[1 : k, :]
119
113
end
120
114
end
121
115
122
- function _topk2 (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
123
- if sortby === nothing
124
- return map (matrix -> sort (matrix, dims = 1 ; rev)[1 : k, :], matrices)
125
- else
126
- return map (matrix -> _sort_row (matrix; rev, sortby)[1 : k, :], matrices)
127
- end
116
+ function _sort_batch (matrices:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
117
+ return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
128
118
end
129
119
130
- function _topk_tensor (feat:: DataStore ,numgra, k:: Int ; rev:: Bool = true , sortby = nothing )
131
- matrices = values (feat)
132
- p= map (matrix -> reshape (matrix,size (matrix,1 ),size (matrix,2 )÷ numgra,numgra),matrices)
133
- v= map (x -> _topk2 (collect (eachslice (x,dims= 3 )), k; rev,sortby), p)
134
- p= map (matrix -> reduce (hcat,matrix),v)
120
+ function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
121
+ sortby = nothing )
122
+ tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
123
+ number_graphs)
124
+ sorted_matrix = _sort_batch (collect (eachslice (tensor_matrix, dims = 3 )), k; rev, sortby)
125
+ return reduce (hcat, sorted_matrix)
135
126
end
136
127
137
-
138
-
139
-
140
- function topk_nodes (g:: GNNGraph , k:: Int ; rev = true , sortby = nothing )
141
- return _topk (g. ndata, k; rev, sortby)
128
+ function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
129
+ sortby = nothing )
130
+ if number_graphs== 1
131
+ return _sort_matrix (matrix, k; rev, sortby)
132
+ else
133
+ return _topk_batch (matrix, number_graphs, k; rev, sortby)
134
+ end
142
135
end
143
136
144
- function topk_edges (g:: GNNGraph , k:: Int ; rev = true , sortby = nothing )
145
- return _topk (g. edata, k; rev, sortby)
137
+ function topk_nodes (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
138
+ matrix = getproperty (g. ndata, feat)
139
+ return _topk (matrix, g. num_graphs, k; rev, sortby)
146
140
end
147
141
142
+ function topk_edges (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
143
+ matrix = getproperty (g. edata, feat)
144
+ return _topk (matrix, g. num_graphs, k; rev, sortby)
145
+ end
0 commit comments