Skip to content

Commit e69402e

Browse files
committed
Add docstrings
1 parent 9de994f commit e69402e

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

src/utils.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,13 @@ function broadcast_edges(g::GNNGraph, x)
100100
return gather(x, gi)
101101
end
102102

103+
# return a permuted matrix according to the sorting of the sortby column
103104
function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1)
104-
index = sortperm(view(matrix, sortby, : ); rev)
105-
return matrix[ :, index]
105+
index = sortperm(view(matrix, sortby, :); rev)
106+
return matrix[:, index]
106107
end
107108

109+
# sort and reshape matrix
108110
function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
109111
if sortby === nothing
110112
return sort(matrix, dims = 2; rev)[:, 1:k]
@@ -113,32 +115,45 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby =
113115
end
114116
end
115117

116-
function _sort_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing)
118+
# sort the iterator of batch matrices
119+
function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing)
117120
return map(x -> _sort_matrix(x, k; rev, sortby), matrices)
118121
end
119122

123+
# sort and reshape batch matrix
120124
function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
121125
sortby = nothing)
122126
tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs,
123127
number_graphs)
124-
sorted_matrix = _sort_batch(collect(eachslice(tensor_matrix, dims = 3)), k; rev, sortby)
128+
sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby)
125129
return reduce(hcat, sorted_matrix)
126130
end
127131

132+
# topk for a feature matrix
128133
function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true,
129134
sortby = nothing)
130-
if number_graphs==1
135+
if number_graphs == 1
131136
return _sort_matrix(matrix, k; rev, sortby)
132137
else
133138
return _topk_batch(matrix, number_graphs, k; rev, sortby)
134139
end
135140
end
136141

142+
"""
143+
topk_nodes(g, feat, k; rev = true, sortby = nothing)
144+
145+
Graph-wise top-k on node features `feat` according to the `sortby` feature index.
146+
"""
137147
function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
138148
matrix = getproperty(g.ndata, feat)
139149
return _topk(matrix, g.num_graphs, k; rev, sortby)
140150
end
141151

152+
"""
153+
topk_edges(g, feat, k; rev = true, sortby = nothing)
154+
155+
Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
156+
"""
142157
function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing)
143158
matrix = getproperty(g.edata, feat)
144159
return _topk(matrix, g.num_graphs, k; rev, sortby)

0 commit comments

Comments
 (0)