Skip to content

Commit a2318e4

Browse files
committed
Replace VarName usage with str in type hints and conversion
1 parent 8a436d8 commit a2318e4

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

pymc/model/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,7 +1939,7 @@ def debug_parameters(rv):
19391939
def to_graphviz(
19401940
self,
19411941
*,
1942-
var_names: Iterable[VarName] | None = None,
1942+
var_names: Iterable[str] | None = None,
19431943
formatting: str = "plain",
19441944
save: str | None = None,
19451945
figsize: tuple[int, int] | None = None,
@@ -2143,7 +2143,7 @@ def compile_fn(
21432143
)
21442144

21452145

2146-
def Point(*args, filter_model_vars=False, **kwargs) -> dict[VarName, np.ndarray]:
2146+
def Point(*args, filter_model_vars=False, **kwargs) -> dict[str, np.ndarray]:
21472147
"""Build a point.
21482148
21492149
Uses same args as dict() does.

pymc/model_graph.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def default_data(var: TensorVariable) -> GraphvizNodeKwargs:
170170
}
171171

172172

173-
def get_node_type(var_name: VarName, model) -> NodeType:
173+
def get_node_type(var_name: str, model) -> NodeType:
174174
"""Return the node type of the variable in the model."""
175175
v = model[var_name]
176176

@@ -239,7 +239,7 @@ def __init__(self, model):
239239
self._all_vars = {model[var_name] for var_name in self._all_var_names}
240240
self.var_list = self.model.named_vars.values()
241241

242-
def get_parent_names(self, var: TensorVariable) -> set[VarName]:
242+
def get_parent_names(self, var: TensorVariable) -> set[str]:
243243
if var.owner is None:
244244
return set()
245245

@@ -258,12 +258,12 @@ def _expand(x):
258258
return x.owner.inputs
259259

260260
return {
261-
cast(VarName, ancestor.name) # type: ignore[union-attr]
261+
cast(str, ancestor.name) # type: ignore[union-attr]
262262
for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
263263
if ancestor in named_vars
264264
}
265265

266-
def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
266+
def vars_to_plot(self, var_names: Iterable[str] | None = None) -> list[str]:
267267
if var_names is None:
268268
return self._all_var_names
269269

@@ -294,12 +294,12 @@ def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarNa
294294
return [get_var_name(var) for var in selected_ancestors]
295295

296296
def make_compute_graph(
297-
self, var_names: Iterable[VarName] | None = None
298-
) -> dict[VarName, set[VarName]]:
297+
self, var_names: Iterable[str] | None = None
298+
) -> dict[str, set[str]]:
299299
"""Get map of var_name -> set(input var names) for the model."""
300300
model = self.model
301301
named_vars = self._all_vars
302-
input_map: dict[VarName, set[VarName]] = defaultdict(set)
302+
input_map: dict[str, set[str]] = defaultdict(set)
303303

304304
var_names_to_plot = self.vars_to_plot(var_names)
305305
for var_name in var_names_to_plot:
@@ -316,15 +316,15 @@ def make_compute_graph(
316316
for ancestor in ancestors([obs_var]):
317317
if ancestor not in named_vars:
318318
continue
319-
obs_name = cast(VarName, ancestor.name)
319+
obs_name = cast(str, ancestor.name)
320320
input_map[var_name].discard(obs_name)
321321
input_map[obs_name].add(var_name)
322322

323323
return input_map
324324

325325
def get_plates(
326326
self,
327-
var_names: Iterable[VarName] | None = None,
327+
var_names: Iterable[str] | None = None,
328328
) -> list[Plate]:
329329
"""Rough but surprisingly accurate plate detection.
330330
@@ -386,8 +386,8 @@ def get_plates(
386386

387387
def edges(
388388
self,
389-
var_names: Iterable[VarName] | None = None,
390-
) -> list[tuple[VarName, VarName]]:
389+
var_names: Iterable[str] | None = None,
390+
) -> list[tuple[str, str]]:
391391
"""Get edges between the variables in the model.
392392
393393
Parameters
@@ -402,7 +402,7 @@ def edges(
402402
403403
"""
404404
return [
405-
(VarName(child.replace(":", "&")), VarName(parent.replace(":", "&")))
405+
(str(child.replace(":", "&")), str(parent.replace(":", "&")))
406406
for child, parents in self.make_compute_graph(var_names=var_names).items()
407407
for parent in parents
408408
]
@@ -419,7 +419,7 @@ def nodes(self, plates: list[Plate] | None = None) -> list[NodeInfo]:
419419
def make_graph(
420420
name: str,
421421
plates: list[Plate],
422-
edges: list[tuple[VarName, VarName]],
422+
edges: list[tuple[str, str]],
423423
formatting: str = "plain",
424424
save=None,
425425
figsize=None,
@@ -493,7 +493,7 @@ def make_graph(
493493
def make_networkx(
494494
name: str,
495495
plates: list[Plate],
496-
edges: list[tuple[VarName, VarName]],
496+
edges: list[tuple[str, str]],
497497
formatting: str = "plain",
498498
node_formatters: NodeTypeFormatterMapping | None = None,
499499
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
@@ -563,7 +563,7 @@ def make_networkx(
563563
def model_to_networkx(
564564
model=None,
565565
*,
566-
var_names: Iterable[VarName] | None = None,
566+
var_names: Iterable[str] | None = None,
567567
formatting: str = "plain",
568568
node_formatters: NodeTypeFormatterMapping | None = None,
569569
include_dim_lengths: bool = True,
@@ -657,7 +657,7 @@ def model_to_networkx(
657657
def model_to_graphviz(
658658
model=None,
659659
*,
660-
var_names: Iterable[VarName] | None = None,
660+
var_names: Iterable[str] | None = None,
661661
formatting: str = "plain",
662662
save: str | None = None,
663663
figsize: tuple[int, int] | None = None,

pymc/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,9 @@ def get_default_varnames(var_iterator, include_transformed):
214214
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]
215215

216216

217-
def get_var_name(var) -> VarName:
217+
def get_var_name(var) -> str:
218218
"""Get an appropriate, plain variable name for a variable."""
219-
return VarName(str(getattr(var, "name", var)))
219+
return str(getattr(var, "name", var))
220220

221221

222222
def get_transformed(z):

0 commit comments

Comments
 (0)