Skip to content

Commit 76de9ae

Browse files
committed
Merge branch 'master' into 0.5.x
2 parents 1998335 + 76d66fd commit 76de9ae

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

examples/pytorch/rgcn/experimental/partition_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def load_ogb(dataset, global_norm):
5050
if ntype == category:
5151
category_id = i
5252

53-
g = dgl.to_homo(hg)
53+
g = dgl.to_homogeneous(hg, edata=['norm'])
5454
if global_norm:
5555
u, v, eid = g.all_edges(form='all')
5656
_, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)

python/dgl/dataloading/dataloader.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .. import backend as F
99
from .. import utils
1010
from ..convert import heterograph
11+
from ..distributed.dist_graph import DistGraph
1112

1213
# pylint: disable=unused-argument
1314
def assign_block_eids(block, frontier):
@@ -244,6 +245,7 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None):
244245
assign_block_eids(block, frontier)
245246

246247
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
248+
247249
# Pre-generate CSR format so that it can be used in training directly
248250
block.create_formats_()
249251
blocks.insert(0, block)
@@ -309,6 +311,7 @@ class NodeCollator(Collator):
309311
"""
310312
def __init__(self, g, nids, block_sampler):
311313
self.g = g
314+
self._is_distributed = isinstance(g, DistGraph)
312315
if not isinstance(nids, Mapping):
313316
assert len(g.ntypes) == 1, \
314317
"nids should be a dict of node type and ids for graph with multiple node types"
@@ -352,6 +355,15 @@ def collate(self, items):
352355
if isinstance(items[0], tuple):
353356
# returns a list of pairs: group them by node types into a dict
354357
items = utils.group_as_dict(items)
358+
359+
# TODO(BarclayII) Because DistGraph doesn't have idtype and device implemented,
360+
# this function does not work. I'm again skipping this step as a workaround.
361+
# We need to fix this.
362+
if not self._is_distributed:
363+
if isinstance(items, dict):
364+
items = utils.prepare_tensor_dict(self.g, items, 'items')
365+
else:
366+
items = utils.prepare_tensor(self.g, items, 'items')
355367
blocks = self.block_sampler.sample_blocks(self.g, items)
356368
output_nodes = blocks[-1].dstdata[NID]
357369
input_nodes = blocks[0].srcdata[NID]
@@ -559,10 +571,11 @@ def dataset(self):
559571

560572
def _collate(self, items):
561573
if isinstance(items[0], tuple):
574+
# returns a list of pairs: group them by node types into a dict
562575
items = utils.group_as_dict(items)
563-
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
576+
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
564577
else:
565-
items = F.zerocopy_from_numpy(np.asarray(items))
578+
items = utils.prepare_tensor(self.g_sampling, items, 'items')
566579

567580
pair_graph = self.g.edge_subgraph(items)
568581
seed_nodes = pair_graph.ndata[NID]
@@ -582,10 +595,11 @@ def _collate(self, items):
582595

583596
def _collate_with_negative_sampling(self, items):
584597
if isinstance(items[0], tuple):
598+
# returns a list of pairs: group them by node types into a dict
585599
items = utils.group_as_dict(items)
586-
items = {k: F.zerocopy_from_numpy(np.asarray(v)) for k, v in items.items()}
600+
items = utils.prepare_tensor_dict(self.g_sampling, items, 'items')
587601
else:
588-
items = F.zerocopy_from_numpy(np.asarray(items))
602+
items = utils.prepare_tensor(self.g_sampling, items, 'items')
589603

590604
pair_graph = self.g.edge_subgraph(items, preserve_nodes=True)
591605
induced_edges = pair_graph.edata[EID]

0 commit comments

Comments
 (0)