Skip to content

Commit b9cdd5c

Browse files
committed
Rename internal variables in change_rv_size
1 parent 72f255d commit b9cdd5c

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

pymc/aesaraf.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,16 @@ def pandas_to_array(data):
143143

144144

145145
def change_rv_size(
146-
rv_var: TensorVariable,
146+
rv: TensorVariable,
147147
new_size: PotentialShapeType,
148148
expand: Optional[bool] = False,
149149
) -> TensorVariable:
150150
"""Change or expand the size of a `RandomVariable`.
151151
152152
Parameters
153153
==========
154-
rv_var
155-
The `RandomVariable` output.
154+
rv
155+
The old `RandomVariable` output.
156156
new_size
157157
The new size.
158158
expand:
@@ -167,32 +167,32 @@ def change_rv_size(
167167
new_size = (new_size,)
168168

169169
# Extract the RV node that is to be resized, together with its inputs, name and tag
170-
if isinstance(rv_var.owner.op, SpecifyShape):
171-
rv_var = rv_var.owner.inputs[0]
172-
rv_node = rv_var.owner
170+
if isinstance(rv.owner.op, SpecifyShape):
171+
rv = rv.owner.inputs[0]
172+
rv_node = rv.owner
173173
rng, size, dtype, *dist_params = rv_node.inputs
174-
name = rv_var.name
175-
tag = rv_var.tag
174+
name = rv.name
175+
tag = rv.tag
176176

177177
if expand:
178-
old_shape = tuple(rv_node.op._infer_shape(size, dist_params))
179-
old_size = old_shape[: len(old_shape) - rv_node.op.ndim_supp]
180-
new_size = tuple(new_size) + tuple(old_size)
178+
shape = tuple(rv_node.op._infer_shape(size, dist_params))
179+
size = shape[: len(shape) - rv_node.op.ndim_supp]
180+
new_size = tuple(new_size) + tuple(size)
181181

182182
# Make sure the new size is a tensor. This dtype-aware conversion helps
183183
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
184184
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
185185

186186
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
187-
rv_var = new_rv_node.outputs[-1]
188-
rv_var.name = name
187+
new_rv = new_rv_node.outputs[-1]
188+
new_rv.name = name
189189
for k, v in tag.__dict__.items():
190-
rv_var.tag.__dict__.setdefault(k, v)
190+
new_rv.tag.__dict__.setdefault(k, v)
191191

192192
if config.compute_test_value != "off":
193193
compute_test_value(new_rv_node)
194194

195-
return rv_var
195+
return new_rv
196196

197197

198198
def extract_rv_and_value_vars(

pymc/distributions/distribution.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def __new__(
269269

270270
if resize_shape:
271271
# A batch size was specified through `dims`, or implied by `observed`.
272-
rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True)
272+
rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True)
273273

274274
rv_out = model.register_rv(
275275
rv_out,
@@ -355,7 +355,7 @@ def dist(
355355
# Replicate dimensions may be prepended via a shape with Ellipsis as the last element:
356356
if shape is not None and Ellipsis in shape:
357357
replicate_shape = cast(StrongShape, shape[:-1])
358-
rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)
358+
rv_out = change_rv_size(rv=rv_out, new_size=replicate_shape, expand=True)
359359

360360
rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
361361
rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")

0 commit comments

Comments
 (0)