Skip to content

Commit 6fec0f0

Browse files
authored
Add edge_sample dispatch and CPU kernel (#594)
1 parent 8b5daf0 commit 6fec0f0

File tree

5 files changed

+190
-0
lines changed

5 files changed

+190
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include "../edge_sampler.h"
2+
3+
#include <ATen/ATen.h>
4+
#include <torch/library.h>
5+
6+
#include <cmath>
7+
#include <unordered_set>
8+
#include <vector>
9+
10+
namespace pyg {
11+
namespace ops {
12+
13+
namespace {
14+
15+
at::Tensor edge_sample_kernel(const at::Tensor& start,
16+
const at::Tensor& rowptr,
17+
int64_t count,
18+
double factor) {
19+
auto start_data = start.data_ptr<int64_t>();
20+
auto rowptr_data = rowptr.data_ptr<int64_t>();
21+
22+
std::vector<int64_t> e_ids;
23+
24+
for (int64_t i = 0; i < start.size(0); i++) {
25+
auto row_start = rowptr_data[start_data[i]];
26+
auto row_end = rowptr_data[start_data[i] + 1];
27+
auto num_neighbors = row_end - row_start;
28+
29+
int64_t size = count;
30+
if (count < 1)
31+
size = static_cast<int64_t>(std::ceil(factor * double(num_neighbors)));
32+
if (size > num_neighbors)
33+
size = num_neighbors;
34+
35+
if (size < 0.7 * double(num_neighbors)) {
36+
std::unordered_set<int64_t> set;
37+
while (static_cast<int64_t>(set.size()) < size) {
38+
int64_t sample = std::rand() % num_neighbors;
39+
set.insert(sample + row_start);
40+
}
41+
std::vector<int64_t> v(set.begin(), set.end());
42+
e_ids.insert(e_ids.end(), v.begin(), v.end());
43+
} else {
44+
auto sample = at::randperm(num_neighbors, start.options());
45+
auto sample_data = sample.data_ptr<int64_t>();
46+
for (int64_t j = 0; j < size; j++) {
47+
e_ids.push_back(sample_data[j] + row_start);
48+
}
49+
}
50+
}
51+
52+
int64_t length = static_cast<int64_t>(e_ids.size());
53+
return at::from_blob(e_ids.data(), {length}, start.options()).clone();
54+
}
55+
56+
} // namespace
57+
58+
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
59+
m.impl(TORCH_SELECTIVE_NAME("pyg::edge_sample"),
60+
TORCH_FN(edge_sample_kernel));
61+
}
62+
63+
} // namespace ops
64+
} // namespace pyg

pyg_lib/csrc/ops/edge_sampler.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "edge_sampler.h"
2+
3+
#include <ATen/core/dispatch/Dispatcher.h>
4+
#include <torch/library.h>
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor edge_sample(const at::Tensor& start,
10+
const at::Tensor& rowptr,
11+
int64_t count,
12+
double factor) {
13+
at::TensorArg start_arg{start, "start", 0};
14+
at::TensorArg rowptr_arg{rowptr, "rowptr", 1};
15+
at::CheckedFrom c{"edge_sample"};
16+
17+
at::checkAllDefined(c, {start_arg, rowptr_arg});
18+
at::checkDim(c, start_arg, 1);
19+
at::checkDim(c, rowptr_arg, 1);
20+
21+
static auto op = c10::Dispatcher::singleton()
22+
.findSchemaOrThrow("pyg::edge_sample", "")
23+
.typed<decltype(edge_sample)>();
24+
return op.call(start, rowptr, count, factor);
25+
}
26+
27+
TORCH_LIBRARY_FRAGMENT(pyg, m) {
28+
m.def(
29+
TORCH_SELECTIVE_SCHEMA("pyg::edge_sample(Tensor start, Tensor rowptr, "
30+
"int count=0, float factor=1.0) -> Tensor"));
31+
}
32+
33+
} // namespace ops
34+
} // namespace pyg

pyg_lib/csrc/ops/edge_sampler.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
#include "pyg_lib/csrc/macros.h"
5+
6+
namespace pyg {
7+
namespace ops {
8+
9+
PYG_API at::Tensor edge_sample(const at::Tensor& start,
10+
const at::Tensor& rowptr,
11+
int64_t count,
12+
double factor);
13+
14+
} // namespace ops
15+
} // namespace pyg

pyg_lib/ops/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,30 @@ def graclus_cluster(
543543
return torch.ops.pyg.graclus_cluster(rowptr, col, weight)
544544

545545

546+
def edge_sample(
547+
start: Tensor,
548+
rowptr: Tensor,
549+
count: int = 0,
550+
factor: float = 1.0,
551+
) -> Tensor:
552+
r"""Samples edges incident to the given start nodes.
553+
554+
For each start node, samples up to :obj:`count` edges. If
555+
:obj:`count < 1`, samples :obj:`ceil(factor * degree)` edges instead.
556+
557+
Args:
558+
start: Start node indices of shape :obj:`[S]`.
559+
rowptr: CSR row pointer of shape :obj:`[N + 1]`.
560+
count: Fixed number of edges to sample per node. If :obj:`< 1`,
561+
uses :obj:`factor` instead.
562+
factor: Fraction of edges to sample when :obj:`count < 1`.
563+
564+
Returns:
565+
Sampled edge indices (into the edge list).
566+
"""
567+
return torch.ops.pyg.edge_sample(start, rowptr, count, factor)
568+
569+
546570
__all__ = [
547571
'grouped_matmul',
548572
'segment_matmul',
@@ -560,4 +584,5 @@ def graclus_cluster(
560584
'radius',
561585
'nearest',
562586
'graclus_cluster',
587+
'edge_sample',
563588
]

test/ops/test_edge_sampler.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
3+
import pyg_lib
4+
from pyg_lib.testing import withCUDA
5+
6+
7+
@withCUDA
8+
def test_edge_sample_count(device: torch.device) -> None:
9+
if device.type == 'cuda':
10+
return # CPU only
11+
# Graph: node 0 has 3 edges, node 1 has 2 edges
12+
rowptr = torch.tensor([0, 3, 5], dtype=torch.long, device=device)
13+
start = torch.tensor([0, 1], dtype=torch.long, device=device)
14+
15+
out = pyg_lib.ops.edge_sample(start, rowptr, count=2)
16+
assert out.numel() == 4 # 2 per node * 2 nodes
17+
18+
# All sampled edge indices should be valid
19+
assert (out >= 0).all()
20+
assert (out < 5).all()
21+
22+
# Node 0 edges in [0, 3), node 1 edges in [3, 5)
23+
node0_edges = out[:2]
24+
node1_edges = out[2:]
25+
assert (node0_edges < 3).all()
26+
assert (node1_edges >= 3).all()
27+
28+
29+
@withCUDA
30+
def test_edge_sample_factor(device: torch.device) -> None:
31+
if device.type == 'cuda':
32+
return # CPU only
33+
# Node with 10 edges
34+
rowptr = torch.tensor([0, 10], dtype=torch.long, device=device)
35+
start = torch.tensor([0], dtype=torch.long, device=device)
36+
37+
out = pyg_lib.ops.edge_sample(start, rowptr, count=0, factor=0.5)
38+
# ceil(0.5 * 10) = 5
39+
assert out.numel() == 5
40+
41+
42+
@withCUDA
43+
def test_edge_sample_cap(device: torch.device) -> None:
44+
if device.type == 'cuda':
45+
return # CPU only
46+
# Node with 3 edges, request 10
47+
rowptr = torch.tensor([0, 3], dtype=torch.long, device=device)
48+
start = torch.tensor([0], dtype=torch.long, device=device)
49+
50+
out = pyg_lib.ops.edge_sample(start, rowptr, count=10)
51+
# Capped at degree = 3
52+
assert out.numel() == 3

0 commit comments

Comments
 (0)