88from .. import backend as F
99from .. import utils
1010from ..convert import heterograph
11+ from ..distributed .dist_graph import DistGraph
1112
1213# pylint: disable=unused-argument
1314def assign_block_eids (block , frontier ):
@@ -244,6 +245,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None):
244245 assign_block_eids (block , frontier )
245246
246247 seed_nodes = {ntype : block .srcnodes [ntype ].data [NID ] for ntype in block .srctypes }
248+
247249 # Pre-generate CSR format so that it can be used in training directly
248250 block .create_formats_ ()
249251 blocks .insert (0 , block )
@@ -309,6 +311,7 @@ class NodeCollator(Collator):
309311 """
310312 def __init__ (self , g , nids , block_sampler ):
311313 self .g = g
314+ self ._is_distributed = isinstance (g , DistGraph )
312315 if not isinstance (nids , Mapping ):
313316 assert len (g .ntypes ) == 1 , \
314317 "nids should be a dict of node type and ids for graph with multiple node types"
@@ -352,6 +355,15 @@ def collate(self, items):
352355 if isinstance (items [0 ], tuple ):
353356 # returns a list of pairs: group them by node types into a dict
354357 items = utils .group_as_dict (items )
358+
359+ # TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
360+ # this function does not work. I'm again skipping this step as a workaround.
361+ # We need to fix this.
362+ if not self ._is_distributed :
363+ if isinstance (items , dict ):
364+ items = utils .prepare_tensor_dict (self .g , items , 'items' )
365+ else :
366+ items = utils .prepare_tensor (self .g , items , 'items' )
355367 blocks = self .block_sampler .sample_blocks (self .g , items )
356368 output_nodes = blocks [- 1 ].dstdata [NID ]
357369 input_nodes = blocks [0 ].srcdata [NID ]
@@ -559,10 +571,11 @@ def dataset(self):
559571
560572 def _collate (self , items ):
561573 if isinstance (items [0 ], tuple ):
574+ # returns a list of pairs: group them by node types into a dict
562575 items = utils .group_as_dict (items )
563- items = { k : F . zerocopy_from_numpy ( np . asarray ( v )) for k , v in items . items ()}
576+ items = utils . prepare_tensor_dict ( self . g_sampling , items , ' items' )
564577 else :
565- items = F . zerocopy_from_numpy ( np . asarray ( items ) )
578+ items = utils . prepare_tensor ( self . g_sampling , items , 'items' )
566579
567580 pair_graph = self .g .edge_subgraph (items )
568581 seed_nodes = pair_graph .ndata [NID ]
@@ -582,10 +595,11 @@ def _collate(self, items):
582595
583596 def _collate_with_negative_sampling (self , items ):
584597 if isinstance (items [0 ], tuple ):
598+ # returns a list of pairs: group them by node types into a dict
585599 items = utils .group_as_dict (items )
586- items = { k : F . zerocopy_from_numpy ( np . asarray ( v )) for k , v in items . items ()}
600+ items = utils . prepare_tensor_dict ( self . g_sampling , items , ' items' )
587601 else :
588- items = F . zerocopy_from_numpy ( np . asarray ( items ) )
602+ items = utils . prepare_tensor ( self . g_sampling , items , 'items' )
589603
590604 pair_graph = self .g .edge_subgraph (items , preserve_nodes = True )
591605 induced_edges = pair_graph .edata [EID ]
0 commit comments