@@ -170,7 +170,7 @@ def default_data(var: TensorVariable) -> GraphvizNodeKwargs:
170170 }
171171
172172
173- def get_node_type (var_name : VarName , model ) -> NodeType :
173+ def get_node_type (var_name : str , model ) -> NodeType :
174174 """Return the node type of the variable in the model."""
175175 v = model [var_name ]
176176
@@ -239,7 +239,7 @@ def __init__(self, model):
239239 self ._all_vars = {model [var_name ] for var_name in self ._all_var_names }
240240 self .var_list = self .model .named_vars .values ()
241241
242- def get_parent_names (self , var : TensorVariable ) -> set [VarName ]:
242+ def get_parent_names (self , var : TensorVariable ) -> set [str ]:
243243 if var .owner is None :
244244 return set ()
245245
@@ -258,12 +258,12 @@ def _expand(x):
258258 return x .owner .inputs
259259
260260 return {
261- cast (VarName , ancestor .name ) # type: ignore[union-attr]
261+ cast (str , ancestor .name ) # type: ignore[union-attr]
262262 for ancestor in walk (nodes = var .owner .inputs , expand = _expand )
263263 if ancestor in named_vars
264264 }
265265
266- def vars_to_plot (self , var_names : Iterable [VarName ] | None = None ) -> list [VarName ]:
266+ def vars_to_plot (self , var_names : Iterable [str ] | None = None ) -> list [str ]:
267267 if var_names is None :
268268 return self ._all_var_names
269269
@@ -294,12 +294,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
294294 return [get_var_name (var ) for var in selected_ancestors ]
295295
296296 def make_compute_graph (
297- self , var_names : Iterable [VarName ] | None = None
298- ) -> dict [VarName , set [VarName ]]:
297+ self , var_names : Iterable [str ] | None = None
298+ ) -> dict [str , set [str ]]:
299299 """Get map of var_name -> set(input var names) for the model."""
300300 model = self .model
301301 named_vars = self ._all_vars
302- input_map : dict [VarName , set [VarName ]] = defaultdict (set )
302+ input_map : dict [str , set [str ]] = defaultdict (set )
303303
304304 var_names_to_plot = self .vars_to_plot (var_names )
305305 for var_name in var_names_to_plot :
@@ -316,15 +316,15 @@ def make_compute_graph(
316316 for ancestor in ancestors ([obs_var ]):
317317 if ancestor not in named_vars :
318318 continue
319- obs_name = cast (VarName , ancestor .name )
319+ obs_name = cast (str , ancestor .name )
320320 input_map [var_name ].discard (obs_name )
321321 input_map [obs_name ].add (var_name )
322322
323323 return input_map
324324
325325 def get_plates (
326326 self ,
327- var_names : Iterable [VarName ] | None = None ,
327+ var_names : Iterable [str ] | None = None ,
328328 ) -> list [Plate ]:
329329 """Rough but surprisingly accurate plate detection.
330330
@@ -386,8 +386,8 @@ def get_plates(
386386
387387 def edges (
388388 self ,
389- var_names : Iterable [VarName ] | None = None ,
390- ) -> list [tuple [VarName , VarName ]]:
389+ var_names : Iterable [str ] | None = None ,
390+ ) -> list [tuple [str , str ]]:
391391 """Get edges between the variables in the model.
392392
393393 Parameters
@@ -402,7 +402,7 @@ def edges(
402402
403403 """
404404 return [
405- (VarName (child .replace (":" , "&" )), VarName (parent .replace (":" , "&" )))
405+ (str (child .replace (":" , "&" )), str (parent .replace (":" , "&" )))
406406 for child , parents in self .make_compute_graph (var_names = var_names ).items ()
407407 for parent in parents
408408 ]
@@ -419,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
419419def make_graph (
420420 name : str ,
421421 plates : list [Plate ],
422- edges : list [tuple [VarName , VarName ]],
422+ edges : list [tuple [str , str ]],
423423 formatting : str = "plain" ,
424424 save = None ,
425425 figsize = None ,
@@ -493,7 +493,7 @@ def make_graph(
493493def make_networkx (
494494 name : str ,
495495 plates : list [Plate ],
496- edges : list [tuple [VarName , VarName ]],
496+ edges : list [tuple [str , str ]],
497497 formatting : str = "plain" ,
498498 node_formatters : NodeTypeFormatterMapping | None = None ,
499499 create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
@@ -563,7 +563,7 @@ def make_networkx(
563563def model_to_networkx (
564564 model = None ,
565565 * ,
566- var_names : Iterable [VarName ] | None = None ,
566+ var_names : Iterable [str ] | None = None ,
567567 formatting : str = "plain" ,
568568 node_formatters : NodeTypeFormatterMapping | None = None ,
569569 include_dim_lengths : bool = True ,
@@ -657,7 +657,7 @@ def model_to_networkx(
657657def model_to_graphviz (
658658 model = None ,
659659 * ,
660- var_names : Iterable [VarName ] | None = None ,
660+ var_names : Iterable [str ] | None = None ,
661661 formatting : str = "plain" ,
662662 save : str | None = None ,
663663 figsize : tuple [int , int ] | None = None ,
0 commit comments