1+ import jittor as jt
2+
3+
4+ def unsorted_segment_sum (x , segment_ids , num_segments ):
5+ if num_segments is None :
6+ num_segments = int (segment_ids .asnumpy ().max () + 1 )
7+
8+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
9+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
10+ if len (segment_ids .shape ) == 1 :
11+ s = jt .prod (jt .array (tuple (x .shape [1 :]))).to (jt .int32 ).item ()
12+ segment_ids = segment_ids .repeat_interleave (s ).view (segment_ids .shape [0 ], * x .shape [1 :])
13+
14+ assert x .shape == segment_ids .shape , "data.shape and segment_ids.shape should be equal"
15+
16+ shape = [num_segments ] + list (x .shape [1 :])
17+ tensor = jt .zeros (* shape ).to (x .dtype ).scatter (0 , segment_ids , x , 'add' )
18+ return tensor
19+
20+ def unsorted_segment_mean (x , segment_ids , num_segments = None ):
21+ if num_segments is None :
22+ num_segments = int (segment_ids .numpy ().max () + 1 )
23+
24+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
25+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
26+ res = []
27+ for i in range (num_segments ):
28+ mask_index = segment_ids == i
29+ if jt .any (mask_index ):
30+ a = jt .mean (x [mask_index ], 0 )
31+ res .append (a )
32+ else :
33+ a = jt .zeros_like (x [0 ])
34+ res .append (a )
35+ if res [0 ].shape == [1 ]:
36+ return jt .concat (res , 0 )
37+ else :
38+ return jt .stack (res , 0 )
39+
40+ def unsorted_segment_max (x , segment_ids , num_segments = None ):
41+ if num_segments is None :
42+ num_segments = int (segment_ids .numpy ().max () + 1 )
43+
44+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
45+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
46+ res = []
47+ for i in range (num_segments ):
48+ mask_index = segment_ids == i
49+ if jt .any (mask_index ):
50+ res .append (jt .max (x [mask_index ], 0 )[0 ])
51+ else :
52+ a = jt .zeros_like (x [0 ])
53+ a .fill_ (jt .array (float ('-inf' )).to (a .dtype ))
54+ res .append (a )
55+ if res [0 ].shape == [1 ]:
56+ return jt .concat (res , 0 )
57+ else :
58+ return jt .stack (res , 0 )
59+
60+
61+ def segment_sum (x , segment_ids , num_segments = None ):
62+ if num_segments is None :
63+ num_segments = int (segment_ids .numpy ().max () + 1 )
64+
65+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
66+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
67+ if len (segment_ids .shape ) == 1 :
68+ s = jt .prod (jt .array (x .shape [1 :])).to (jt .int32 )
69+ segment_ids = segment_ids .repeat_interleave (s ).view (segment_ids .shape [0 ], * x .shape [1 :])
70+
71+ assert x .shape == segment_ids .shape , "data.shape and segment_ids.shape should be equal"
72+
73+ shape = [num_segments ] + list (x .shape [1 :])
74+ tensor = jt .zeros (* shape ).to (x .dtype ).scatter_add (0 , segment_ids , x )
75+ return tensor
76+
77+
78+
79+ def segment_mean (x , segment_ids , num_segments = None ):
80+ if num_segments is None :
81+ num_segments = int (segment_ids .numpy ().max () + 1 )
82+
83+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
84+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
85+ res = []
86+ for i in range (num_segments ):
87+ mask_index = segment_ids == i
88+ if jt .any (mask_index ):
89+ a = jt .mean (x [mask_index ], 0 )
90+ res .append (a )
91+ else :
92+ a = jt .zeros_like (x [0 ])
93+ res .append (a )
94+ if res [0 ].shape == [1 ]:
95+ return jt .concat (res , 0 )
96+ else :
97+ return jt .stack (res , 0 )
98+
99+ def segment_max (x , segment_ids , num_segments = None ):
100+ if num_segments is None :
101+ num_segments = int (segment_ids .numpy ().max () + 1 )
102+
103+ segment_ids = jt .array (segment_ids , dtype = jt .int64 )
104+ assert x .shape [0 ] == segment_ids .shape [0 ], "the length of segment_ids should be equal to data.shape[0]."
105+ res = []
106+ for i in range (num_segments ):
107+ mask_index = segment_ids == i
108+ if jt .any (mask_index ):
109+ res .append (jt .max (x [mask_index ], 0 )[0 ])
110+ else :
111+ a = jt .zeros_like (x [0 ])
112+ a .fill_ (jt .array (float ('-inf' )).to (a .dtype ))
113+ res .append (a )
114+ if res [0 ].shape == [1 ]:
115+ return jt .concat (res , 0 )
116+ else :
117+ return jt .stack (res , 0 )
118+
119+ def gspmm (index , weight = None , x = None , reduce = 'sum' ):
120+ pass
121+
122+ def bspmm (index , weight = None , x = None , reduce = 'sum' ):
123+ pass
0 commit comments