@@ -69,6 +69,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
6969 fanout_array = [None ] * len (g .etypes )
7070 for etype , value in fanout .items ():
7171 fanout_array [g .get_etype_id (etype )] = value
72+ fanout_array = utils .toindex (fanout_array ).todgltensor ()
7273
7374 if prob is None :
7475 prob_arrays = [nd .array ([], ctx = nd .cpu ())] * len (g .etypes )
@@ -100,7 +101,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
100101 ----------
101102 g : DGLHeteroGraph
102103 Full graph structure.
103- k : int
104+ k : int or dict[etype, int]
104105 The K value.
105106 weight : str
106107 Feature name of the weights associated with each edge. Its shape should be
@@ -138,11 +139,16 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
138139 else :
139140 nodes_all_types .append (nd .array ([], ctx = nd .cpu ()))
140141
141- if not isinstance (k , list ):
142- k = [int (k )] * len (g .etypes )
143- if len (k ) != len (g .etypes ):
144- raise DGLError ('K value must be specified for each edge type '
145- 'if a list is provided.' )
142+ if not isinstance (k , dict ):
143+ k_array = [int (k )] * len (g .etypes )
144+ else :
145+ if len (k ) != len (g .etypes ):
146+ raise DGLError ('K value must be specified for each edge type '
147+ 'if a dict is provided.' )
148+ k_array = [None ] * len (g .etypes )
149+ for etype , value in k .items ():
150+ k_array [g .get_etype_id (etype )] = value
151+ k_array = utils .toindex (k_array ).todgltensor ()
146152
147153 weight_arrays = []
148154 for etype in g .canonical_etypes :
@@ -153,7 +159,7 @@ def select_topk(g, k, weight, nodes=None, edge_dir='in', ascending=False):
153159 weight , etype ))
154160
155161 subgidx = _CAPI_DGLSampleNeighborsTopk (
156- g ._graph , nodes_all_types , k , edge_dir , weight_arrays , bool (ascending ))
162+ g ._graph , nodes_all_types , k_array , edge_dir , weight_arrays , bool (ascending ))
157163 induced_edges = subgidx .induced_edges
158164 ret = DGLHeteroGraph (subgidx .graph , g .ntypes , g .etypes )
159165 for i , etype in enumerate (ret .canonical_etypes ):
0 commit comments