Skip to content

Commit 7bb6bf8

Browse files
Armavicatwiecki
authored andcommitted
Fix UP038 (isinstance(..., X | Y))
1 parent 2a86c6b commit 7bb6bf8

38 files changed

+75
-79
lines changed

pymc/backends/arviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def is_data(name, var, model) -> bool:
7878
and var not in model.potentials
7979
and var not in model.value_vars
8080
and name not in observations
81-
and isinstance(var, (Constant, SharedVariable))
81+
and isinstance(var, Constant | SharedVariable)
8282
)
8383

8484
# The assumption is that constants (like pm.Data) are named

pymc/distributions/censored.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Censored(Distribution):
8989
@classmethod
9090
def dist(cls, dist, lower, upper, **kwargs):
9191
if not isinstance(dist, TensorVariable) or not isinstance(
92-
dist.owner.op, (RandomVariable, SymbolicRandomVariable)
92+
dist.owner.op, RandomVariable | SymbolicRandomVariable
9393
):
9494
raise ValueError(
9595
f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"

pymc/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def rewrite_support_point_scan_node(self, node):
102102

103103
for nd in local_fgraph_topo:
104104
if nd not in to_replace_set and isinstance(
105-
nd.op, (RandomVariable, SymbolicRandomVariable)
105+
nd.op, RandomVariable | SymbolicRandomVariable
106106
):
107107
replace_with_support_point.append(nd.out)
108108
to_replace_set.add(nd)
@@ -132,7 +132,7 @@ def add_requirements(self, fgraph):
132132

133133
def apply(self, fgraph):
134134
for node in fgraph.toposort():
135-
if isinstance(node.op, (RandomVariable, SymbolicRandomVariable)):
135+
if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
136136
fgraph.replace(node.out, support_point(node.out))
137137
elif isinstance(node.op, Scan):
138138
new_node = self.rewrite_support_point_scan_node(node)
@@ -837,7 +837,7 @@ def custom_dist_get_support_point(op, rv, size, *params):
837837
*[
838838
p
839839
for p in params
840-
if not isinstance(p.type, (RandomType, RandomGeneratorType))
840+
if not isinstance(p.type, RandomType | RandomGeneratorType)
841841
],
842842
)
843843

pymc/distributions/mixture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class Mixture(Distribution):
178178

179179
@classmethod
180180
def dist(cls, w, comp_dists, **kwargs):
181-
if not isinstance(comp_dists, (tuple, list)):
181+
if not isinstance(comp_dists, tuple | list):
182182
# comp_dists is a single component
183183
comp_dists = [comp_dists]
184184
elif len(comp_dists) == 1:
@@ -204,7 +204,7 @@ def dist(cls, w, comp_dists, **kwargs):
204204
# TODO: Allow these to not be a RandomVariable as long as we can call `ndim_supp` on them
205205
# and resize them
206206
if not isinstance(dist, TensorVariable) or not isinstance(
207-
dist.owner.op, (RandomVariable, SymbolicRandomVariable)
207+
dist.owner.op, RandomVariable | SymbolicRandomVariable
208208
):
209209
raise ValueError(
210210
f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}"

pymc/distributions/multivariate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
10831083

10841084
def _lkj_normalizing_constant(eta, n):
10851085
# TODO: This is mixing python branching with the potentially symbolic n and eta variables
1086-
if not isinstance(eta, (int, float)):
1086+
if not isinstance(eta, int | float):
10871087
raise NotImplementedError("eta must be an int or float")
10881088
if not isinstance(n, int):
10891089
raise NotImplementedError("n must be an integer")
@@ -1185,7 +1185,7 @@ def dist(cls, n, eta, sd_dist, **kwargs):
11851185
if not (
11861186
isinstance(sd_dist, Variable)
11871187
and sd_dist.owner is not None
1188-
and isinstance(sd_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
1188+
and isinstance(sd_dist.owner.op, RandomVariable | SymbolicRandomVariable)
11891189
and sd_dist.owner.op.ndim_supp < 2
11901190
):
11911191
raise TypeError("sd_dist must be a scalar or vector distribution variable")
@@ -2262,7 +2262,7 @@ def logp(value, mu, W, alpha, tau):
22622262
TensorVariable
22632263
"""
22642264

2265-
sparse = isinstance(W, (pytensor.sparse.SparseConstant, pytensor.sparse.SparseVariable))
2265+
sparse = isinstance(W, pytensor.sparse.SparseConstant | pytensor.sparse.SparseVariable)
22662266

22672267
if sparse:
22682268
D = sp_sum(W, axis=0)

pymc/distributions/shape_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def convert_dims(dims: Dims | None) -> StrongDims | None:
193193

194194
if isinstance(dims, str):
195195
dims = (dims,)
196-
elif isinstance(dims, (list, tuple)):
196+
elif isinstance(dims, list | tuple):
197197
dims = tuple(dims)
198198
else:
199199
raise ValueError(f"The `dims` parameter must be a tuple, str or list. Actual: {type(dims)}")
@@ -209,7 +209,7 @@ def convert_shape(shape: Shape) -> StrongShape | None:
209209
shape = (shape,)
210210
elif isinstance(shape, TensorVariable) and shape.ndim == 1:
211211
shape = tuple(shape)
212-
elif isinstance(shape, (list, tuple)):
212+
elif isinstance(shape, list | tuple):
213213
shape = tuple(shape)
214214
else:
215215
raise ValueError(
@@ -227,7 +227,7 @@ def convert_size(size: Size) -> StrongSize | None:
227227
size = (size,)
228228
elif isinstance(size, TensorVariable) and size.ndim == 1:
229229
size = tuple(size)
230-
elif isinstance(size, (list, tuple)):
230+
elif isinstance(size, list | tuple):
231231
size = tuple(size)
232232
else:
233233
raise ValueError(

pymc/distributions/timeseries.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> pt.TensorVari
8888
if not (
8989
isinstance(init_dist, pt.TensorVariable)
9090
and init_dist.owner is not None
91-
and isinstance(init_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
91+
and isinstance(init_dist.owner.op, RandomVariable | SymbolicRandomVariable)
9292
):
9393
raise TypeError("init_dist must be a distribution variable")
9494
check_dist_not_registered(init_dist)
9595

9696
if not (
9797
isinstance(innovation_dist, pt.TensorVariable)
9898
and innovation_dist.owner is not None
99-
and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
99+
and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
100100
):
101101
raise TypeError("innovation_dist must be a distribution variable")
102102
check_dist_not_registered(innovation_dist)
@@ -129,7 +129,7 @@ def get_steps(cls, innovation_dist, steps, shape, dims, observed):
129129
if not (
130130
isinstance(innovation_dist, pt.TensorVariable)
131131
and innovation_dist.owner is not None
132-
and isinstance(innovation_dist.owner.op, (RandomVariable, SymbolicRandomVariable))
132+
and isinstance(innovation_dist.owner.op, RandomVariable | SymbolicRandomVariable)
133133
):
134134
raise TypeError("innovation_dist must be a distribution variable")
135135

@@ -549,7 +549,7 @@ def dist(
549549

550550
if init_dist is not None:
551551
if not isinstance(init_dist, TensorVariable) or not isinstance(
552-
init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
552+
init_dist.owner.op, RandomVariable | SymbolicRandomVariable
553553
):
554554
raise ValueError(
555555
f"Init dist must be a distribution created via the `.dist()` API, "
@@ -948,7 +948,7 @@ def dist(cls, dt, sde_fn, sde_pars, *, init_dist=None, steps=None, **kwargs):
948948

949949
if init_dist is not None:
950950
if not isinstance(init_dist, TensorVariable) or not isinstance(
951-
init_dist.owner.op, (RandomVariable, SymbolicRandomVariable)
951+
init_dist.owner.op, RandomVariable | SymbolicRandomVariable
952952
):
953953
raise ValueError(
954954
f"Init dist must be a distribution created via the `.dist()` API, "

pymc/gp/cov.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,11 +256,7 @@ def _merge_factors_cov(self, X, Xs=None, diag=False):
256256

257257
elif isinstance(
258258
factor,
259-
(
260-
TensorConstant,
261-
TensorVariable,
262-
TensorSharedVariable,
263-
),
259+
TensorConstant | TensorVariable | TensorSharedVariable,
264260
):
265261
if factor.ndim == 2 and diag:
266262
factor_list.append(pt.diag(factor))
@@ -524,7 +520,7 @@ def __init__(
524520
if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None):
525521
raise ValueError("Only one of 'ls' or 'ls_inv' must be provided")
526522
elif ls_inv is not None:
527-
if isinstance(ls_inv, (list, tuple)):
523+
if isinstance(ls_inv, list | tuple):
528524
ls = 1.0 / np.asarray(ls_inv)
529525
else:
530526
ls = 1.0 / ls_inv

pymc/gp/hsgp_approx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def prior_linearized(self, Xs: TensorLike):
328328

329329
# If not provided, use Xs and c to set L
330330
if self._L is None:
331-
assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable))
331+
assert isinstance(self._c, numbers.Real | np.ndarray | pt.TensorVariable)
332332
self.L = pt.as_tensor(set_boundary(Xs, self._c))
333333
else:
334334
self.L = self._L

pymc/gp/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def kmeans_inducing_points(n_inducing, X, **kmeans_kwargs):
113113
# first whiten X
114114
if isinstance(X, TensorConstant):
115115
X = X.value
116-
elif isinstance(X, (np.ndarray, tuple, list)):
116+
elif isinstance(X, np.ndarray | tuple | list):
117117
X = np.asarray(X)
118118
else:
119119
raise TypeError(

0 commit comments

Comments
 (0)