@@ -61,19 +61,20 @@ def rv_dict_to_flat_array_wrapper(
6161
6262 @wraps (fn )
6363 def seeded_array_fn (seed : SeedType = None ):
64- inital_value_dict = fn (seed )
65- total_size = sum (np .prod (shape ) for shape in shapes )
64+ initial_value_dict = fn (seed )
65+ total_size = sum (np .prod (shape ). astype ( int ) for shape in shapes )
6666 flat_array = np .empty (total_size , dtype = "float64" , order = "C" )
6767 cursor = 0
6868
6969 for name , shape in zip (names , shapes ):
70- initial_value = inital_value_dict [name ]
70+ initial_value = initial_value_dict [name ]
7171 n = int (np .prod (initial_value .shape ))
7272 if initial_value .shape != shape :
7373 raise ValueError (
7474 f"Size of initial value for { name } is { initial_value .shape } , "
7575 f"expected { shape } "
7676 )
77+
7778 flat_array [cursor : cursor + n ] = initial_value .ravel ().astype ("float64" )
7879 cursor += n
7980
@@ -144,16 +145,16 @@ def with_data(self, **updates):
144145 user_data = user_data ,
145146 )
146147
147- def _make_sampler (self , settings , init_mean , cores , progress_type ):
148- model = self ._make_model (init_mean )
148+ def _make_sampler (self , settings , cores , progress_type ):
149+ model = self ._make_model ()
149150 return _lib .PySampler .from_pymc (
150151 settings ,
151152 cores ,
152153 model ,
153154 progress_type ,
154155 )
155156
156- def _make_model (self , init_mean ):
157+ def _make_model (self ):
157158 expand_fn = _lib .ExpandFunc (
158159 self .n_dim ,
159160 self .n_expanded ,
@@ -169,14 +170,15 @@ def _make_model(self, init_mean):
169170 )
170171
171172 var_sizes = [prod (shape ) for shape in self .shape_info [2 ]]
173+ var_names = self .shape_info [0 ]
172174
173175 return _lib .PyMcModel (
174176 self .n_dim ,
175177 logp_fn ,
176178 expand_fn ,
177179 self .initial_point_func ,
178180 var_sizes ,
179- self . shape_info [ 0 ] ,
181+ var_names ,
180182 )
181183
182184
@@ -472,7 +474,7 @@ def compile_pymc_model(
472474 overrides = overrides ,
473475 default_strategy = default_strategy ,
474476 jitter_rvs = jitter_rvs ,
475- return_transformed = False ,
477+ return_transformed = True ,
476478 )
477479
478480 if backend .lower () == "numba" :
0 commit comments