@@ -42,91 +42,6 @@ def is_scalar(x):
4242 return False
4343 return False
4444
45-
46- class _FlattenIndexMapping (object ):
47- def __init__ (self , stride = 1 , reverse = False ):
48- self ._stride = stride
49- self .reverse = reverse
50-
51- def __call__ (self , idxs : _HybridIndex ):
52- new_idxs = []
53-
54- if self .reverse == True :
55- for i in idxs :
56- new_idxs .append ( _HybridIndex ( idx = (i .idx // self ._stride ), root_idx = i .root_idx ) )
57- new_idxs = list (set (new_idxs ))
58- else :
59- for i in idxs :
60- new_idxs .extend (
61- [ _HybridIndex (idx = k , root_idx = i .root_idx ) for k in range (i .idx * self ._stride , (i .idx + 1 ) * self ._stride ) ]
62- )
63- return new_idxs
64-
65-
66- class _ConcatIndexMapping (object ):
67- def __init__ (self , offset , reverse = False ):
68- self .offset = offset
69- self .reverse = reverse
70-
71- def __call__ (self , idxs : _HybridIndex ):
72- if self .reverse == True :
73- new_idxs = [
74- _HybridIndex (idx = i .idx - self .offset [0 ], root_idx = i .root_idx )
75- for i in idxs
76- if (i .idx >= self .offset [0 ] and i .idx < self .offset [1 ])
77- ]
78- else :
79- new_idxs = [ _HybridIndex (idx = i .idx + self .offset [0 ], root_idx = i .root_idx ) for i in idxs ]
80- return new_idxs
81-
82- class _GQAIndexMapping (object ):
83- def __init__ (self , repeat , head_dim , reverse = False ):
84- self .repeat = repeat
85- self .reverse = reverse
86- self .head_dim = head_dim
87-
88- def __call__ (self , idxs : _HybridIndex ):
89- head_dim = self .head_dim
90- repeat = self .repeat
91- if self .reverse == True :
92- new_idxs = [ _HybridIndex (idx = ( i .idx - i .idx // (head_dim * repeat ) * head_dim * (repeat - 1 ) - i .idx // head_dim % repeat * head_dim ), root_idx = None ) for i in idxs ]
93- else :
94- new_idxs = []
95-
96- return new_idxs
97-
98- class _SliceIndexMapping (object ):
99- def __init__ (self , dim , start , step , end , reverse = False ):
100- self .start = start
101- self .step = step
102- self .end = end
103- self .reverse = reverse
104- self .dim = dim
105-
106- def __call__ (self , idxs : _HybridIndex ):
107-
108- if self .reverse == True :
109- new_idxs = [ _HybridIndex (idx = i .idx * self .step + self .start , root_idx = i .root_idx ) for i in idxs ]
110- else :
111- new_idxs = [ _HybridIndex (idx = (i .idx - self .start ) // self .step , root_idx = i .root_idx ) for i in idxs if (i .idx >= self .start and i .idx < self .end and (i .idx - self .start )% self .step == 0 ) ]
112- return new_idxs
113-
114- class _SplitIndexMapping (object ):
115- def __init__ (self , offset , reverse = False ):
116- self .offset = offset
117- self .reverse = reverse
118-
119- def __call__ (self , idxs : _HybridIndex ):
120- if self .reverse == True :
121- new_idxs = [ _HybridIndex (idx = i .idx + self .offset [0 ], root_idx = i .root_idx ) for i in idxs ]
122- else :
123- new_idxs = [
124- _HybridIndex (idx = i .idx - self .offset [0 ], root_idx = i .root_idx )
125- for i in idxs
126- if (i .idx >= self .offset [0 ] and i .idx < self .offset [1 ])
127- ]
128- return new_idxs
129-
13045class ScalarSum :
13146 def __init__ (self ):
13247 self ._results = {}
0 commit comments