@@ -143,16 +143,16 @@ def pandas_to_array(data):
143
143
144
144
145
145
def change_rv_size (
146
- rv_var : TensorVariable ,
146
+ rv : TensorVariable ,
147
147
new_size : PotentialShapeType ,
148
148
expand : Optional [bool ] = False ,
149
149
) -> TensorVariable :
150
150
"""Change or expand the size of a `RandomVariable`.
151
151
152
152
Parameters
153
153
==========
154
- rv_var
155
- The `RandomVariable` output.
154
+ rv
155
+ The old `RandomVariable` output.
156
156
new_size
157
157
The new size.
158
158
expand:
@@ -167,32 +167,32 @@ def change_rv_size(
167
167
new_size = (new_size ,)
168
168
169
169
# Extract the RV node that is to be resized, together with its inputs, name and tag
170
- if isinstance (rv_var .owner .op , SpecifyShape ):
171
- rv_var = rv_var .owner .inputs [0 ]
172
- rv_node = rv_var .owner
170
+ if isinstance (rv .owner .op , SpecifyShape ):
171
+ rv = rv .owner .inputs [0 ]
172
+ rv_node = rv .owner
173
173
rng , size , dtype , * dist_params = rv_node .inputs
174
- name = rv_var .name
175
- tag = rv_var .tag
174
+ name = rv .name
175
+ tag = rv .tag
176
176
177
177
if expand :
178
- old_shape = tuple (rv_node .op ._infer_shape (size , dist_params ))
179
- old_size = old_shape [: len (old_shape ) - rv_node .op .ndim_supp ]
180
- new_size = tuple (new_size ) + tuple (old_size )
178
+ shape = tuple (rv_node .op ._infer_shape (size , dist_params ))
179
+ size = shape [: len (shape ) - rv_node .op .ndim_supp ]
180
+ new_size = tuple (new_size ) + tuple (size )
181
181
182
182
# Make sure the new size is a tensor. This dtype-aware conversion helps
183
183
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
184
184
new_size = at .as_tensor (new_size , ndim = 1 , dtype = "int64" )
185
185
186
186
new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
187
- rv_var = new_rv_node .outputs [- 1 ]
188
- rv_var .name = name
187
+ new_rv = new_rv_node .outputs [- 1 ]
188
+ new_rv .name = name
189
189
for k , v in tag .__dict__ .items ():
190
- rv_var .tag .__dict__ .setdefault (k , v )
190
+ new_rv .tag .__dict__ .setdefault (k , v )
191
191
192
192
if config .compute_test_value != "off" :
193
193
compute_test_value (new_rv_node )
194
194
195
- return rv_var
195
+ return new_rv
196
196
197
197
198
198
def extract_rv_and_value_vars (
0 commit comments