Skip to content

Commit 95ccf8e

Browse files
Finish teaching CGPM metamodel to handle latents.
1 parent 1b8d22c commit 95ccf8e

File tree

1 file changed

+36
-15
lines changed

1 file changed

+36
-15
lines changed

src/metamodels/cgpm_metamodel.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ def create_generator(self, bdb, generator_id, schema_tokens):
103103
table = core.bayesdb_population_table(bdb, population_id)
104104
qt = sqlite3_quote_name(table)
105105

106+
# Assign latent variable numbers.
107+
for var, stattype in sorted(schema['latents'].iteritems()):
108+
core.bayesdb_add_latent(
109+
bdb, population_id, generator_id, var, stattype)
110+
106111
# Assign codes to categories and consecutive column numbers to
107112
# the modelled variables.
108113
vars_cursor = bdb.sql_execute('''
@@ -617,16 +622,18 @@ def _create_schema(bdb, generator_id, schema_ast):
617622
# State.
618623
variables = []
619624
categoricals = {}
620-
declared_latents = []
625+
latents = {}
621626
cgpm_composition = []
622627
modelled = set()
628+
default_modelled = set()
623629
subsample = None
624630

625631
# Error-reporting state.
626632
duplicate = set()
627633
unknown = set()
628634
needed = set()
629635
invalid_latent = set()
636+
must_exist = []
630637

631638
def _retrieve_stattype_dist_params(var):
632639
colno = core.bayesdb_variable_number(bdb, population_id, var)
@@ -660,23 +667,31 @@ def _retrieve_stattype_dist_params(var):
660667
invalid_latent.add(var)
661668
continue
662669

663-
# Add it to the list and mark it modelled.
670+
# Add it to the list and mark it modelled by default.
664671
stattype = core.bayesdb_variable_stattype(
665672
bdb, population_id, colno)
666673
variables.append([var, stattype, dist, params])
667674
modelled.add(var)
675+
default_modelled.add(var)
668676

669677
elif isinstance(clause, cgpm_schema.parse.Latent):
670-
# Reject if the latent variable has already been declared.
671-
if any(l[0] == clause.name for l in declared_latents):
672-
duplicate.add(clause.name)
678+
var = clause.name
679+
stattype = clause.stattype
673680

674-
# XXX FILL ME XXX
681+
# Reject if the variable has already been modelled by the
682+
# default model.
683+
if var in default_modelled:
684+
duplicate.add(var)
685+
continue
675686

676-
# Register the latent variable name and stattype into bayesdb.
677-
# Do something related to the error checking data structures.
687+
# Reject if a variable of this name already exists in the
688+
# population.
689+
if core.bayesdb_has_variable(bdb, population_id, var):
690+
duplicate.add(var)
691+
continue
678692

679-
declared_latents.append((clause.name, clause.stattype))
693+
# Add it to the set of latent variables.
694+
latents[var] = stattype
680695

681696
elif isinstance(clause, cgpm_schema.parse.Foreign):
682697
# Foreign model: some set of output variables is to be
@@ -699,9 +714,7 @@ def _retrieve_stattype_dist_params(var):
699714
# First make sure all the output variables exist and have
700715
# not yet been modelled.
701716
for var in clause.outputs:
702-
if not core.bayesdb_has_variable(bdb, population_id, var):
703-
unknown.add(var)
704-
continue
717+
must_exist.append(var)
705718
if var in modelled:
706719
duplicate.add(var)
707720
break
@@ -711,9 +724,7 @@ def _retrieve_stattype_dist_params(var):
711724
# them needed, and record where to put their
712725
# distribution type and parameters.
713726
for var in inputs:
714-
if not core.bayesdb_has_variable(bdb, population_id, var):
715-
unknown.add(var)
716-
continue
727+
must_exist.append(var)
717728
needed.add(var)
718729
# XXX check agreement with statistical type
719730
assert len(cctypes) == len(ccargs)
@@ -737,6 +748,15 @@ def _retrieve_stattype_dist_params(var):
737748
else:
738749
raise BQLError(bdb, 'Unknown clause: %r' % (clause,))
739750

751+
# Make sure all the outputs and inputs exist, either in the
752+
# population or as latents in this generator.
753+
for var in must_exist:
754+
if core.bayesdb_has_variable(bdb, population_id, var):
755+
continue
756+
if var in latents:
757+
continue
758+
missing.add(var)
759+
740760
# Raise an exception if there were duplicates or unknown
741761
# variables.
742762
if duplicate:
@@ -775,6 +795,7 @@ def _retrieve_stattype_dist_params(var):
775795
'variables': variables,
776796
'cgpm_composition': cgpm_composition,
777797
'subsample': subsample,
798+
'latents': latents,
778799
}
779800

780801
def _default_categorical(bdb, generator_id, var):

0 commit comments

Comments
 (0)