@@ -166,18 +166,15 @@ def _get_row_at(
166
166
result = dict (zip (var_names , data [0 ][0 ]))
167
167
return result
168
168
169
- def _get_rows ( # pylint: disable=W0221
169
+ def _get_rows (
170
170
self ,
171
171
var_name : str ,
172
172
nshape : Optional [Sequence [int ]],
173
173
dtype : str ,
174
- * ,
175
- burn : int = 0 ,
174
+ slc : slice = slice (None ),
176
175
) -> numpy .ndarray :
177
176
self ._commit ()
178
- data = self ._client .execute (
179
- f"SELECT (`{ var_name } `) FROM { self .cid } WHERE _draw_idx>={ burn } ;"
180
- )
177
+ data = self ._client .execute (f"SELECT (`{ var_name } `) FROM { self .cid } ;" )
181
178
draws = len (data )
182
179
183
180
# Safety checks
@@ -201,20 +198,20 @@ def _get_rows( # pylint: disable=W0221
201
198
arr [:] = buffer
202
199
return arr
203
200
# Otherwise (identical shapes) we can collapse into one ndarray
204
- return numpy .asarray (buffer , dtype = dtype )
201
+ return numpy .asarray (buffer , dtype = dtype )[ slc ]
205
202
206
- def get_draws (self , var_name : str ) -> numpy .ndarray :
203
+ def get_draws (self , var_name : str , slc : slice = slice ( None ) ) -> numpy .ndarray :
207
204
var = self .variables [var_name ]
208
205
nshape = var .shape if not var .undefined_ndim else None
209
- return self ._get_rows (var_name , nshape , var .dtype )
206
+ return self ._get_rows (var_name , nshape , var .dtype , slc )
210
207
211
208
def get_draws_at (self , idx : int , var_names : Sequence [str ]) -> Dict [str , numpy .ndarray ]:
212
209
return self ._get_row_at (idx , var_names )
213
210
214
- def get_stats (self , stat_name : str ) -> numpy .ndarray :
211
+ def get_stats (self , stat_name : str , slc : slice = slice ( None ) ) -> numpy .ndarray :
215
212
var = self .sample_stats [stat_name ]
216
213
nshape = var .shape if not var .undefined_ndim else None
217
- return self ._get_rows (f"__stat_{ stat_name } " , nshape , var .dtype )
214
+ return self ._get_rows (f"__stat_{ stat_name } " , nshape , var .dtype , slc )
218
215
219
216
def get_stats_at (self , idx : int , stat_names : Sequence [str ]) -> Dict [str , numpy .ndarray ]:
220
217
stats = self ._get_row_at (idx , [f"__stat_{ sname } " for sname in stat_names ])
0 commit comments