Skip to content

Commit e810543

Browse files
Merge remote-tracking branch 'origin/20160628-fsaad-cgpm' into 20160624-riastradh-cgpm
2 parents 0d3fc83 + fbc0f27 commit e810543

File tree

4 files changed

+127
-43
lines changed

4 files changed

+127
-43
lines changed

src/metamodels/cgpm_metamodel.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from bayeslite.util import json_dumps
3636

3737
import cgpm_schema.parse
38+
import cgpm_analyze.parse
3839

3940
CGPM_SCHEMA_1 = '''
4041
INSERT INTO bayesdb_metamodel (name, version) VALUES ('cgpm', 1);
@@ -190,7 +191,7 @@ def initialize_models(self, bdb, generator_id, modelnos):
190191
for cgpm_ext in schema['cgpm_composition']:
191192
cgpms = [self._initialize_cgpm(bdb, generator_id, cgpm_ext)
192193
for _ in xrange(n)]
193-
engine.compose_cgpm(cgpms, N=1, multithread=self._ncpu)
194+
engine.compose_cgpm(cgpms, multithread=self._ncpu)
194195

195196
# Store the newly initialized engine.
196197
engine_json = json_dumps(engine.to_metadata())
@@ -221,16 +222,46 @@ def analyze_models(self, bdb, generator_id, modelnos=None, iterations=1,
221222

222223
if ckpt_iterations is not None or ckpt_seconds is not None:
223224
# XXX
224-
raise NotImplementedError('cgpm analysis checkpoint')
225-
if program is not None:
226-
# XXX
227-
raise NotImplementedError('cgpm analysis programs')
225+
raise NotImplementedError('CGpm analysis checkpoint not supported.')
228226

229-
# Get the engine.
230-
engine = self._engine(bdb, generator_id)
227+
if program is None:
228+
program = []
229+
230+
population_id = core.bayesdb_generator_population(bdb, generator_id)
231+
232+
def retrieve_analyze_variables(ast):
233+
# Transition all variables by default.
234+
if len(ast) == 0:
235+
variables = core.bayesdb_variable_names(bdb, population_id)
236+
# Exactly 1 clause supported.
237+
elif len(ast) == 1:
238+
clause = ast[0]
239+
# Transition user specified variables only.
240+
if isinstance(clause, cgpm_analyze.parse.Variables):
241+
variables = clause.vars
242+
# Transition all variables except user specified skip.
243+
elif isinstance(clause, cgpm_analyze.parse.Skip):
244+
variables = filter(
245+
lambda v: v not in clause.vars,
246+
core.bayesdb_variable_names(bdb, population_id))
247+
# Unknown/impossible clause.
248+
else:
249+
raise ValueError('Unknown clause in ANALYZE: %s.' % ast)
250+
# Reject programs with more than 1 clause.
251+
else:
252+
raise ValueError('1 clause permitted in ANALYZE: %s.' % ast)
253+
return variables
231254

232-
# Do the transition.
233-
engine.transition(N=iterations, S=max_seconds, multithread=self._ncpu)
255+
# Retrieve target variables.
256+
analyze_ast = cgpm_analyze.parse.parse(program)
257+
variables = retrieve_analyze_variables(analyze_ast)
258+
varnos = [core.bayesdb_variable_number(bdb, population_id, v)
259+
for v in variables]
260+
261+
# Run transition.
262+
engine = self._engine(bdb, generator_id)
263+
engine.transition(
264+
N=iterations, S=max_seconds, cols=varnos, multithread=self._ncpu)
234265

235266
# Serialize the engine.
236267
engine_json = json_dumps(engine.to_metadata())
@@ -323,7 +354,7 @@ def simulate_joint(self, bdb, generator_id, targets, constraints, modelno,
323354
cgpm_query = [colno for _r, colno in targets]
324355
cgpm_evidence = {
325356
colno: self._to_numeric(bdb, generator_id, colno, value)
326-
for colno, value in constraints
357+
for _r, colno, value in constraints
327358
}
328359
engine = self._engine(bdb, generator_id)
329360
samples = engine.simulate(
@@ -586,6 +617,7 @@ def _create_schema(bdb, generator_id, schema_ast):
586617
# State.
587618
variables = []
588619
categoricals = {}
620+
declared_latents = []
589621
cgpm_composition = []
590622
modelled = set()
591623
subsample = None
@@ -634,6 +666,18 @@ def _retrieve_stattype_dist_params(var):
634666
variables.append([var, stattype, dist, params])
635667
modelled.add(var)
636668

669+
elif isinstance(clause, cgpm_schema.parse.Latent):
670+
# Flag the name as a duplicate if this latent variable was already declared.
671+
if any(l[0] == clause.name for l in declared_latents):
672+
duplicate.add(clause.name)
673+
674+
# XXX FILL ME XXX
675+
676+
# Register the latent variable name and stattype into bayesdb.
677+
# Do something related to the error checking data structures.
678+
679+
declared_latents.append((clause.name, clause.stattype))
680+
637681
elif isinstance(clause, cgpm_schema.parse.Foreign):
638682
# Foreign model: some set of output variables is to be
639683
# modelled by foreign logic, possibly conditional on some
@@ -677,22 +721,21 @@ def _retrieve_stattype_dist_params(var):
677721
_, dist, params = _retrieve_stattype_dist_params(var)
678722
cctypes.append(dist)
679723
ccargs.append(params)
680-
else:
681-
# Finally, add a cgpm_composition record.
682-
cgpm_composition.append({
683-
'name': name,
684-
'inputs': inputs,
685-
'outputs': outputs,
686-
'kwds': kwds,
687-
})
724+
# Finally, add a cgpm_composition record.
725+
cgpm_composition.append({
726+
'name': name,
727+
'inputs': inputs,
728+
'outputs': outputs,
729+
'kwds': kwds,
730+
})
688731

689732
elif isinstance(clause, cgpm_schema.parse.Subsample):
690733
if subsample is not None:
691734
raise BQLError(bdb, 'Duplicate subsample: %r' % (clause.n,))
692735
subsample = clause.n
693736

694737
else:
695-
assert False
738+
raise BQLError(bdb, 'Unknown clause: %r' % (clause,))
696739

697740
# Raise an exception if there were duplicates or unknown
698741
# variables.

src/metamodels/cgpm_schema/grammar.y

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,38 @@
2121
* - L_ means a lexeme, which has useful associated text, e.g. an integer.
2222
*/
2323

24-
cgpm(empty) ::= .
25-
cgpm(schema) ::= schema(s).
24+
cgpm(empty) ::= .
25+
cgpm(schema) ::= schema(s).
2626

27-
schema(one) ::= clause(c).
28-
schema(some) ::= schema(s) T_COMMA clause_opt(c).
27+
schema(one) ::= clause(c).
28+
schema(some) ::= schema(s) T_COMMA clause_opt(c).
2929

30-
clause_opt(none)::= .
31-
clause_opt(some)::= clause(c).
30+
clause_opt(none) ::= .
31+
clause_opt(some) ::= clause(c).
3232

33-
clause(basic) ::= var(var) dist(dist) param_opt(params).
34-
clause(foreign) ::= K_MODEL vars(outputs) given_opt(inputs)
35-
K_USING foreign(name) param_opt(params).
36-
clause(subsamp) ::= K_SUBSAMPLE L_NUMBER(n).
33+
clause(basic) ::= var(var) dist(dist) param_opt(params).
34+
clause(foreign) ::= K_MODEL vars(outputs) given_opt(inputs)
35+
K_USING foreign(name) param_opt(params).
36+
clause(subsamp) ::= K_SUBSAMPLE L_NUMBER(n).
37+
clause(latent) ::= K_LATENT var(var) stattype(st).
3738

38-
dist(name) ::= L_NAME(dist).
39-
foreign(name) ::= L_NAME(foreign).
39+
dist(name) ::= L_NAME(dist).
40+
foreign(name) ::= L_NAME(foreign).
4041

41-
given_opt(none) ::= .
42-
given_opt(some) ::= K_GIVEN vars(vars).
42+
given_opt(none) ::= .
43+
given_opt(some) ::= K_GIVEN vars(vars).
4344

44-
vars(one) ::= var(var).
45-
vars(many) ::= vars(vars) T_COMMA var(var).
46-
var(name) ::= L_NAME(var).
45+
vars(one) ::= var(var).
46+
vars(many) ::= vars(vars) T_COMMA var(var).
4747

48-
param_opt(none) ::= .
49-
param_opt(some) ::= T_LROUND params(ps) T_RROUND.
50-
params(one) ::= param(param).
51-
params(many) ::= params(params) T_COMMA param(param).
48+
var(name) ::= L_NAME(var).
5249

53-
param(num) ::= L_NAME(p) T_EQ L_NUMBER(num).
54-
param(nam) ::= L_NAME(p) T_EQ L_NAME(nam).
50+
stattype(s) ::= L_NAME(st).
51+
52+
param_opt(none) ::= .
53+
param_opt(some) ::= T_LROUND params(ps) T_RROUND.
54+
params(one) ::= param(param).
55+
params(many) ::= params(params) T_COMMA param(param).
56+
57+
param(num) ::= L_NAME(p) T_EQ L_NUMBER(num).
58+
param(nam) ::= L_NAME(p) T_EQ L_NAME(nam).

src/metamodels/cgpm_schema/parse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
KEYWORDS = {
3333
'given': grammar.K_GIVEN,
34+
'latent': grammar.K_LATENT,
3435
'model': grammar.K_MODEL,
3536
'subsample': grammar.K_SUBSAMPLE,
3637
'using': grammar.K_USING,
@@ -132,6 +133,8 @@ def p_clause_foreign(self, outputs, inputs, name, params):
132133
return Foreign(outputs, inputs, name, params)
133134
def p_clause_subsamp(self, n):
134135
return Subsample(n)
136+
def p_clause_latent(self, var, st):
137+
return Latent(var, st)
135138

136139
def p_dist_name(self, dist): return casefold(dist)
137140
def p_foreign_name(self, foreign): return casefold(foreign)
@@ -141,7 +144,9 @@ def p_given_opt_some(self, vars): return vars
141144

142145
def p_vars_one(self, var): return [var]
143146
def p_vars_many(self, vars, var): vars.append(var); return vars
144-
def p_var_name(self, var): return var
147+
def p_var_name(self, var): return casefold(var)
148+
149+
def p_stattype_s(self, st): return st
145150

146151
def p_param_opt_none(self): return []
147152
def p_param_opt_some(self, ps): return ps
@@ -166,3 +171,8 @@ def p_param_nam(self, p, nam): return (p, nam)
166171
Subsample = namedtuple('Subsample', [
167172
'n',
168173
])
174+
175+
Latent = namedtuple('Latent', [
176+
'name',
177+
'stattype',
178+
])

tests/test_vscgpm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import math
1818
import numpy as np
19+
import pytest
1920
import random # XXX
2021

2122
#from cgpm.regressions.forest import RandomForest
@@ -165,6 +166,8 @@ def test_cgpm_extravaganza__ci_slow():
165166
bdb.execute('''
166167
CREATE GENERATOR g0 FOR satellites USING cgpm (
167168
apogee NORMAL,
169+
LATENT kepler_cluster_id NUMERICAL,
170+
LATENT kepler_noise NUMERICAL,
168171
MODEL kepler_cluster_id, kepler_noise, period
169172
GIVEN apogee, perigee
170173
USING venturescript (source = "{}"),
@@ -177,6 +180,30 @@ def test_cgpm_extravaganza__ci_slow():
177180

178181
# -- MODEL country_of_operator GIVEN class_of_orbit USING forest;
179182
bdb.execute('INITIALIZE 1 MODELS FOR g0')
183+
bdb.execute('ANALYZE g0 FOR 1 iteration WAIT (;)')
184+
bdb.execute('''
185+
ANALYZE g0 FOR 1 iteration WAIT (VARIABLES kepler_cluster_id)
186+
''')
187+
bdb.execute('''
188+
ANALYZE g0 FOR 1 iteration WAIT (
189+
SKIP kepler_cluster_id, kepler_noise, period;
190+
)
191+
''')
192+
with pytest.raises(Exception):
193+
# Disallow both SKIP and VARIABLES clauses.
194+
#
195+
# XXX Catch a more specific exception.
196+
bdb.execute('''
197+
ANALYZE g0 FOR 1 ITERATION WAIT (
198+
SKIP kepler_cluster_id;
199+
VARIABLES apogee, perigee;
200+
)
201+
''')
202+
bdb.execute('''
203+
ANALYZE g0 FOR 1 iteration WAIT (
204+
SKIP kepler_cluster_id, kepler_noise, period;
205+
)
206+
''')
180207
bdb.execute('ANALYZE g0 FOR 1 ITERATION WAIT')
181208

182209
bdb.execute('''

0 commit comments

Comments
 (0)