@@ -100,13 +100,11 @@ function broadcast_edges(g::GNNGraph, x)
100
100
return gather (x, gi)
101
101
end
102
102
103
- # return a permuted matrix according to the sorting of the sortby column
104
103
function _sort_col (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 )
105
104
index = sortperm (view (matrix, sortby, :); rev)
106
105
return matrix[:, index]
107
106
end
108
107
109
- # sort and reshape matrix
110
108
function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
111
109
if sortby === nothing
112
110
return sort (matrix, dims = 2 ; rev)[:, 1 : k]
@@ -115,12 +113,10 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby =
115
113
end
116
114
end
117
115
118
- # sort the iterator of batch matrices
119
116
function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
120
117
return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
121
118
end
122
119
123
- # sort and reshape batch matrix
124
120
function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
125
121
sortby = nothing )
126
122
tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
@@ -129,7 +125,6 @@ function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Boo
129
125
return reduce (hcat, sorted_matrix)
130
126
end
131
127
132
- # topk for a feature matrix
133
128
function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
134
129
sortby = nothing )
135
130
if number_graphs == 1
0 commit comments