@@ -55,6 +55,7 @@ def __init__(
5555 constant : Sequence [Variable ] | None = None ,
5656 until : Variable | None = None ,
5757 name = "ScalarLoop" ,
58+ ** kwargs ,
5859 ):
5960 if constant is None :
6061 constant = []
@@ -75,7 +76,7 @@ def __init__(
7576 self .nout = len (self .outputs )
7677 self .name = name
7778
78- super ().__init__ ()
79+ super ().__init__ (** kwargs )
7980
8081 def output_types (self , input_types ):
8182 return self .outputs_type
@@ -115,7 +116,7 @@ def fgraph(self):
115116 self ._fgraph = fgraph
116117 return self ._fgraph
117118
118- def clone (self ):
119+ def clone (self , name = None , ** kwargs ):
119120 if self .is_while :
120121 * update , until = self .outputs
121122 else :
@@ -127,28 +128,16 @@ def clone(self):
127128 update = update ,
128129 constant = constant ,
129130 until = until ,
130- name = self .name ,
131+ name = self .name if name is None else name ,
132+ ** kwargs ,
131133 )
132134
133135 @property
134136 def fn (self ):
135137 raise NotImplementedError
136138
137139 def make_new_inplace (self , output_types_preference = None , name = None ):
138- """
139- This op.__init__ fct don't have the same parameter as other scalar op.
140- This break the insert_inplace_optimizer optimization.
141- This fct allow fix patch this.
142-
143- """
144- d = {k : getattr (self , k ) for k in self .init_param }
145- out = self .__class__ (** d )
146- if name :
147- out .name = name
148- else :
149- name = out .name
150- super (ScalarLoop , out ).__init__ (output_types_preference , name )
151- return out
140+ return self .clone (output_types_preference = output_types_preference , name = name )
152141
153142 def make_node (self , n_steps , * inputs ):
154143 assert len (inputs ) == self .nin - 1
@@ -229,11 +218,11 @@ def c_code_template(self):
229218 c : f"%(i{ int (i )} )s"
230219 for i , c in enumerate (fgraph .inputs [n_update :], start = n_update + 1 )
231220 }
232- update_subd = {
221+ out_subd = {
233222 u : f"%(o{ int (i )} )s" for i , u in enumerate (fgraph .outputs [:n_update ])
234223 }
235224 until_subd = {u : "until" for u in fgraph .outputs [n_update :]}
236- subd = {** carry_subd , ** constant_subd , ** update_subd , ** until_subd }
225+ subd = {** carry_subd , ** constant_subd , ** until_subd }
237226
238227 for var in fgraph .variables :
239228 if var .owner is None :
@@ -257,11 +246,11 @@ def c_code_template(self):
257246 _c_code += "bool until = 1;\n \n "
258247
259248 # Copy carried inputs
260- for i , (var , name ) in enumerate (carry_subd .items ()):
261- copy_var_name = f"{ name } _copy { i } "
262- _c_code += f"{ var .type .dtype_specs ()[1 ]} { copy_var_name } = { name } ;\n "
263- carry_subd [var ] = copy_var_name
264- subd [var ] = copy_var_name
249+ for i , (var , name ) in enumerate (carry_subd .items (), start = 1 ):
250+ carry_var_name = f"{ name } _carry { i } "
251+ _c_code += f"{ var .type .dtype_specs ()[1 ]} { carry_var_name } = { name } ;\n "
252+ carry_subd [var ] = carry_var_name
253+ subd [var ] = carry_var_name
265254
266255 # _c_code += 'printf("inputs=[");'
267256 # for i in range(1, len(fgraph.inputs)):
@@ -270,9 +259,8 @@ def c_code_template(self):
270259
271260 _c_code += "\n for(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n "
272261
273- self .nodenames = [
274- f"%(nodename)s_subnode{ int (j )} " for j , n in enumerate (fgraph .toposort ())
275- ]
262+ # Used by self.c_support_code_apply
263+ self .nodenames = nodenames = []
276264
277265 i = 0
278266 for j , node in enumerate (fgraph .toposort ()):
@@ -282,9 +270,13 @@ def c_code_template(self):
282270 name = f"V%(id)s_tmp{ int (i )} "
283271 subd [output ] = name
284272 _c_code += f"{ output .type .dtype_specs ()[1 ]} { name } ;\n "
273+
274+ nodename = f"%(nodename)s_subnode{ int (j )} "
275+ nodenames .append (nodename )
276+
285277 s = node .op .c_code (
286278 node ,
287- self . nodenames [ j ] ,
279+ nodename ,
288280 # Any node that depended on `init` will depend on `update` instead
289281 # The initial value of `update` was set to `init` before the loop
290282 [subd [input ] for input in node .inputs ],
@@ -294,10 +286,12 @@ def c_code_template(self):
294286 _c_code += s
295287 _c_code += "\n "
296288
297- # Set the carry variables to the output variables
289+ # Update the carry variables to the output variables
298290 _c_code += "\n "
299- for init , update in zip (carry_subd .values (), update_subd .values (), strict = True ):
300- _c_code += f"{ init } = { update } ;\n "
291+ for carry , out in zip (
292+ carry_subd .values (), fgraph .outputs [:n_update ], strict = True
293+ ):
294+ _c_code += f"{ carry } = { subd [out ]} ;\n "
301295
302296 # _c_code += 'printf("%%ld\\n", i);\n'
303297 # for carry in range(1, 10):
@@ -309,6 +303,10 @@ def c_code_template(self):
309303 # End of the loop
310304 _c_code += "}\n "
311305
306+ # Assign the carry variables to the outputs
307+ for out , carry in zip (out_subd .values (), carry_subd .values (), strict = True ):
308+ _c_code += f"{ out } = { carry } ;\n "
309+
312310 # Output until flag
313311 if self .is_while :
314312 _c_code += f"%(o{ len (fgraph .outputs )- 1 } )s = until;\n "
@@ -343,4 +341,4 @@ def c_code(self, node, nodename, inames, onames, sub):
343341 return res
344342
345343 def c_code_cache_version_outer (self ):
346- return (3 ,)
344+ return (4 ,)
0 commit comments