@@ -206,10 +206,9 @@ def __init__(self):
206206 )
207207
208208
209- def flatten_attributes (component_group , absolute_name : str , d = None ) -> OrderedDict :
210- if d is None :
211- d = OrderedDict ()
212-
209+ def flatten_attributes (
210+ component_group , absolute_name : str , d : OrderedDict
211+ ) -> OrderedDict :
213212 if not hasattr (component_group , "__dict__" ):
214213 return d
215214
@@ -218,14 +217,14 @@ def flatten_attributes(component_group, absolute_name: str, d=None) -> OrderedDi
218217 if name .startswith ("_" ):
219218 # Private attribute
220219 continue
221- elif elem in component_group . __dict__ .values ():
220+ elif elem in d .values ():
222221 # Don't duplicate any tiems
223222 continue
224223 elif isinstance (elem , Component ):
225224 # Only add components to dict
226225 d [new_absolute_name ] = elem
227226 else :
228- d = flatten_attributes (elem , new_absolute_name , d = d )
227+ flatten_attributes (elem , new_absolute_name , d )
229228
230229 return d
231230
@@ -250,26 +249,35 @@ def __init__(self, demo: gr.Blocks) -> None:
250249 show_progress = False ,
251250 )
252251
252+ ignore = ["df" , "predictions_plot" ]
253253 self .run .click (
254- create_processing_function (self , ignore = ["df" , "predictions_plot" ]),
255- inputs = list (flatten_attributes (self , "interface" ).values ()),
254+ create_processing_function (self , ignore = ignore ),
255+ inputs = [
256+ v
257+ for k , v in flatten_attributes (self , "interface" , OrderedDict ()).items ()
258+ if last_part (k ) not in ignore
259+ ],
256260 outputs = [self .results .df , self .results .predictions_plot ],
257261 show_progress = True ,
258262 )
259263
260264
265+ def last_part (k : str ) -> str :
266+ return k .split ("." )[- 1 ]
267+
268+
261269def create_processing_function (interface : AppInterface , ignore = []):
262- d = flatten_attributes (interface , "interface" )
263- keys = [k .split ("." )[- 1 ] for k in d .keys ()]
264- keys = [k for k in keys if k not in ignore ]
270+ d = flatten_attributes (interface , "interface" , OrderedDict ())
271+ keys = [k for k in map (last_part , d .keys ()) if k not in ignore ]
265272 _ , idx , counts = np .unique (keys , return_index = True , return_counts = True )
266273 if np .any (counts > 1 ):
267274 raise AssertionError ("Bad keys: " + "," .join (np .array (keys )[idx [counts > 1 ]]))
268275
269- def f (components ):
276+ def f (* components ):
270277 n = len (components )
271278 assert n == len (keys )
272- return processing (** {keys [i ]: components [i ] for i in range (n )})
279+ for output in processing (** {keys [i ]: components [i ] for i in range (n )}):
280+ yield output
273281
274282 return f
275283
0 commit comments