|
66 | 66 | from bayeslite.util import casefold |
67 | 67 |
|
68 | 68 | import cgpm_schema.parse |
| 69 | +import cgpm_analyze.parse |
69 | 70 |
|
70 | 71 | CGPM_SCHEMA_1 = ''' |
71 | 72 | INSERT INTO bayesdb_metamodel (name, version) VALUES ('cgpm', 1); |
@@ -221,7 +222,7 @@ def initialize_models(self, bdb, generator_id, modelnos): |
221 | 222 | for cgpm_ext in schema['cgpm_composition']: |
222 | 223 | cgpms = [self._initialize_cgpm(bdb, generator_id, cgpm_ext) |
223 | 224 | for _ in xrange(n)] |
224 | | - engine.compose_cgpm(cgpms, N=1, multithread=self._ncpu) |
| 225 | + engine.compose_cgpm(cgpms, multithread=self._ncpu) |
225 | 226 |
|
226 | 227 | # Store the newly initialized engine. |
227 | 228 | engine_json = json_dumps(engine.to_metadata()) |
@@ -252,16 +253,70 @@ def analyze_models(self, bdb, generator_id, modelnos=None, iterations=1, |
252 | 253 |
|
253 | 254 | if ckpt_iterations is not None or ckpt_seconds is not None: |
254 | 255 | # XXX |
255 | | - raise NotImplementedError('cgpm analysis checkpoint') |
256 | | - if program is not None: |
257 | | - # XXX |
258 | | - raise NotImplementedError('cgpm analysis programs') |
| 256 | + raise NotImplementedError('CGpm analysis checkpoint not supported.') |
| 257 | + |
| 258 | + if program is None: |
| 259 | + program = [] |
| 260 | + |
| 261 | + population_id = core.bayesdb_generator_population(bdb, generator_id) |
| 262 | + |
| 263 | + def retrieve_analyze_variables(ast): |
| 264 | + # Transition all variables by default. |
| 265 | + if len(ast) == 0: |
| 266 | + variables = core.bayesdb_variable_names(bdb, population_id) |
| 267 | + # Exactly 1 clause supported. |
| 268 | + elif len(ast) == 1: |
| 269 | + clause = ast[0] |
| 270 | + # Transition user specified variables only. |
| 271 | + if isinstance(clause, cgpm_analyze.parse.Variables): |
| 272 | + variables = clause.vars |
| 273 | + # Transition all variables except user specified skip. |
| 274 | + elif isinstance(clause, cgpm_analyze.parse.Skip): |
| 275 | + variables = filter( |
| 276 | + lambda v: v not in clause.vars, |
| 277 | + core.bayesdb_variable_names(bdb, population_id)) |
| 278 | + # Unknown/impossible clause. |
| 279 | + else: |
| 280 | + raise ValueError('Unknown clause in ANALYZE: %s.' % ast) |
| 281 | + # Crash if more than 1 clause. |
| 282 | + else: |
| 283 | + raise ValueError('1 clause permitted in ANALYZE: %s.' % ast) |
| 284 | + return variables |
| 285 | + |
| 286 | + def foreign(varname): |
| 287 | + schema = self._schema(bdb, generator_id) |
| 288 | + return all(v[0]!=varname for v in schema['variables']) |
| 289 | + |
| 290 | + # Retrieve target variables. |
| 291 | + analyze_ast = cgpm_analyze.parse.parse(program) |
| 292 | + variables = retrieve_analyze_variables(analyze_ast) |
| 293 | + varnames_gpmcc = [v for v in variables if not foreign(v)] |
| 294 | + varnames_foreign = [v for v in variables if foreign(v)] |
259 | 295 |
|
260 | 296 | # Get the engine. |
261 | 297 | engine = self._engine(bdb, generator_id) |
262 | 298 |
|
263 | | - # Do the transition. |
264 | | - engine.transition(N=iterations, S=max_seconds, multithread=self._ncpu) |
| 299 | + # Transition gpmcc variables. |
| 300 | + if varnames_gpmcc: |
| 301 | + print varnames_gpmcc |
| 302 | + varnos_gpmcc = [ |
| 303 | + core.bayesdb_variable_number(bdb, population_id, v) |
| 304 | + for v in varnames_gpmcc |
| 305 | + ] |
| 306 | + engine.transition( |
| 307 | + N=iterations, S=max_seconds, cols=varnos_gpmcc, |
| 308 | + multithread=self._ncpu) |
| 309 | + |
| 310 | + # Transition foreign variables. |
| 311 | + if varnames_foreign: |
| 312 | + print varnames_foreign |
| 313 | + varnos_foreign = [ |
| 314 | + core.bayesdb_variable_number(bdb, population_id, v) |
| 315 | + for v in varnames_foreign |
| 316 | + ] |
| 317 | + engine.transition_foreign( |
| 318 | + N=iterations, S=max_seconds, cols=varnos_foreign, |
| 319 | + multithread=self._ncpu) |
265 | 320 |
|
266 | 321 | # Serialize the engine. |
267 | 322 | engine_json = json_dumps(engine.to_metadata()) |
|
0 commit comments