Skip to content

Commit 1ab7720

Browse files
authored
add jittor mpops (BUPT-GAMMA#227)
1 parent 26d3793 commit 1ab7720

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

gammagl/mpops/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# !/usr/bin/env python3
23
# -*- coding:utf-8 -*-
34

@@ -19,6 +20,9 @@
1920

2021
elif os.environ['TL_BACKEND'] == 'torch':
2122
from .torch import *
23+
24+
elif os.environ['TL_BACKEND'] == 'jittor':
25+
from .jittor import *
2226

2327
else:
2428
raise NotImplementedError("This backend is not supported")

gammagl/mpops/jittor.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import jittor as jt
2+
3+
4+
def unsorted_segment_sum(x, segment_ids, num_segments):
5+
if num_segments is None:
6+
num_segments = int(segment_ids.asnumpy().max() + 1)
7+
8+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
9+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
10+
if len(segment_ids.shape) == 1:
11+
s = jt.prod(jt.array(tuple(x.shape[1:]))).to(jt.int32).item()
12+
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *x.shape[1:])
13+
14+
assert x.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
15+
16+
shape = [num_segments] + list(x.shape[1:])
17+
tensor = jt.zeros(*shape).to(x.dtype).scatter(0, segment_ids, x, 'add')
18+
return tensor
19+
20+
def unsorted_segment_mean(x, segment_ids, num_segments=None):
21+
if num_segments is None:
22+
num_segments = int(segment_ids.numpy().max() + 1)
23+
24+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
25+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
26+
res = []
27+
for i in range(num_segments):
28+
mask_index = segment_ids == i
29+
if jt.any(mask_index):
30+
a = jt.mean(x[mask_index], 0)
31+
res.append(a)
32+
else:
33+
a = jt.zeros_like(x[0])
34+
res.append(a)
35+
if res[0].shape == [1]:
36+
return jt.concat(res, 0)
37+
else:
38+
return jt.stack(res, 0)
39+
40+
def unsorted_segment_max(x, segment_ids, num_segments=None):
41+
if num_segments is None:
42+
num_segments = int(segment_ids.numpy().max() + 1)
43+
44+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
45+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
46+
res = []
47+
for i in range(num_segments):
48+
mask_index = segment_ids == i
49+
if jt.any(mask_index):
50+
res.append(jt.max(x[mask_index], 0)[0])
51+
else:
52+
a = jt.zeros_like(x[0])
53+
a.fill_(jt.array(float('-inf')).to(a.dtype))
54+
res.append(a)
55+
if res[0].shape == [1]:
56+
return jt.concat(res, 0)
57+
else:
58+
return jt.stack(res, 0)
59+
60+
61+
def segment_sum(x, segment_ids, num_segments=None):
62+
if num_segments is None:
63+
num_segments = int(segment_ids.numpy().max() + 1)
64+
65+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
66+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
67+
if len(segment_ids.shape) == 1:
68+
s = jt.prod(jt.array(x.shape[1:])).to(jt.int32)
69+
segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *x.shape[1:])
70+
71+
assert x.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"
72+
73+
shape = [num_segments] + list(x.shape[1:])
74+
tensor = jt.zeros(*shape).to(x.dtype).scatter_add(0, segment_ids, x)
75+
return tensor
76+
77+
78+
79+
def segment_mean(x, segment_ids, num_segments=None):
80+
if num_segments is None:
81+
num_segments = int(segment_ids.numpy().max() + 1)
82+
83+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
84+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
85+
res = []
86+
for i in range(num_segments):
87+
mask_index = segment_ids == i
88+
if jt.any(mask_index):
89+
a = jt.mean(x[mask_index], 0)
90+
res.append(a)
91+
else:
92+
a = jt.zeros_like(x[0])
93+
res.append(a)
94+
if res[0].shape == [1]:
95+
return jt.concat(res, 0)
96+
else:
97+
return jt.stack(res, 0)
98+
99+
def segment_max(x, segment_ids, num_segments=None):
100+
if num_segments is None:
101+
num_segments = int(segment_ids.numpy().max() + 1)
102+
103+
segment_ids = jt.array(segment_ids, dtype=jt.int64)
104+
assert x.shape[0] == segment_ids.shape[0], "the length of segment_ids should be equal to data.shape[0]."
105+
res = []
106+
for i in range(num_segments):
107+
mask_index = segment_ids == i
108+
if jt.any(mask_index):
109+
res.append(jt.max(x[mask_index], 0)[0])
110+
else:
111+
a = jt.zeros_like(x[0])
112+
a.fill_(jt.array(float('-inf')).to(a.dtype))
113+
res.append(a)
114+
if res[0].shape == [1]:
115+
return jt.concat(res, 0)
116+
else:
117+
return jt.stack(res, 0)
118+
119+
def gspmm(index, weight=None, x=None, reduce='sum'):
120+
pass
121+
122+
def bspmm(index, weight=None, x=None, reduce='sum'):
123+
pass

0 commit comments

Comments
 (0)