@@ -99,26 +99,27 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
9999 :py:func:`pymc.util.get_random_generator` for more information.
100100 """
101101 self .vars = vars
102+ self .var_names = tuple (cast (str , var .name ) for var in vars )
102103 self .shared = {get_var_name (var ): shared for var , shared in shared .items ()}
103104 self .blocked = blocked
104105 self .rng = get_random_generator (rng )
105106
106107 def step (self , point : PointType ) -> tuple [PointType , StatsType ]:
107- for name , shared_var in self .shared .items ():
108- shared_var .set_value (point [name ])
109-
110- var_dict = {cast (str , v .name ): point [cast (str , v .name )] for v in self .vars }
111- q = DictToArrayBijection .map (var_dict )
112-
108+ full_point = None
109+ if self .shared :
110+ for name , shared_var in self .shared .items ():
111+ shared_var .set_value (point [name ], borrow = True )
112+ full_point = point
113+ point = {name : point [name ] for name in self .var_names }
114+
115+ q = DictToArrayBijection .map (point )
113116 apoint , stats = self .astep (q )
114117
115118 if not isinstance (apoint , RaveledVars ):
116119 # We assume that the mapping has stayed the same
117120 apoint = RaveledVars (apoint , q .point_map_info )
118121
119- new_point = DictToArrayBijection .rmap (apoint , start_point = point )
120-
121- return new_point , stats
122+ return DictToArrayBijection .rmap (apoint , start_point = full_point ), stats
122123
123124 @abstractmethod
124125 def astep (self , q0 : RaveledVars ) -> tuple [RaveledVars , StatsType ]:
0 commit comments