@@ -528,43 +528,24 @@ def add_requirements(self, fgraph):
528528
529529 @staticmethod
530530 def elemwise_to_scalar (inputs , outputs ):
531- replace_inputs = [(inp , inp .clone ()) for inp in inputs ]
532- outputs = clone_replace (outputs , replace = replace_inputs )
533-
534- inputs = [inp for _ , inp in replace_inputs ]
535- fg = FunctionGraph (inputs = inputs , outputs = outputs , clone = False )
536- middle_inputs = []
537-
538- scalar_inputs = [
539- ps .get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
540- ]
541- middle_scalar_inputs = []
542-
543- for node in fg .toposort ():
544- node_scalar_inputs = []
545- for inp in node .inputs :
546- if inp in inputs :
547- node_scalar_inputs .append (scalar_inputs [inputs .index (inp )])
548- elif inp in middle_inputs :
549- node_scalar_inputs .append (
550- middle_scalar_inputs [middle_inputs .index (inp )]
531+ replacement = {
532+ inp : ps .get_scalar_type (inp .type .dtype ).make_variable () for inp in inputs
533+ }
534+ for node in toposort (outputs , blockers = inputs ):
535+ scalar_inputs = [replacement [inp ] for inp in node .inputs ]
536+ replacement .update (
537+ dict (
538+ zip (
539+ node .outputs ,
540+ node .op .scalar_op .make_node (* scalar_inputs ).outputs ,
551541 )
552- else :
553- new_scalar_input = ps .get_scalar_type (
554- inp .type .dtype
555- ).make_variable ()
556- node_scalar_inputs .append (new_scalar_input )
557- middle_scalar_inputs .append (new_scalar_input )
558- middle_inputs .append (inp )
559-
560- new_scalar_node = node .op .scalar_op .make_node (* node_scalar_inputs )
561- middle_scalar_inputs .append (new_scalar_node .outputs [0 ])
562- middle_inputs .append (node .outputs [0 ])
563-
564- scalar_outputs = [
565- middle_scalar_inputs [middle_inputs .index (out )] for out in fg .outputs
566- ]
567- return scalar_inputs , scalar_outputs
542+ )
543+ )
544+
545+ return (
546+ [replacement [inp ] for inp in inputs ],
547+ [replacement [out ] for out in outputs ],
548+ )
568549
569550 def apply (self , fgraph ):
570551 nb_replacement = 0
0 commit comments