@@ -103,6 +103,11 @@ def create_generator(self, bdb, generator_id, schema_tokens):
103103 table = core .bayesdb_population_table (bdb , population_id )
104104 qt = sqlite3_quote_name (table )
105105
106+ # Assign latent variable numbers.
107+ for var , stattype in sorted (schema ['latents' ].iteritems ()):
108+ core .bayesdb_add_latent (
109+ bdb , population_id , generator_id , var , stattype )
110+
106111 # Assign codes to categories and consecutive column numbers to
107112 # the modelled variables.
108113 vars_cursor = bdb .sql_execute ('''
@@ -617,16 +622,18 @@ def _create_schema(bdb, generator_id, schema_ast):
617622 # State.
618623 variables = []
619624 categoricals = {}
620- declared_latents = []
625+ latents = {}
621626 cgpm_composition = []
622627 modelled = set ()
628+ default_modelled = set ()
623629 subsample = None
624630
625631 # Error-reporting state.
626632 duplicate = set ()
627633 unknown = set ()
628634 needed = set ()
629635 invalid_latent = set ()
636+ must_exist = []
630637
631638 def _retrieve_stattype_dist_params (var ):
632639 colno = core .bayesdb_variable_number (bdb , population_id , var )
@@ -660,23 +667,31 @@ def _retrieve_stattype_dist_params(var):
660667 invalid_latent .add (var )
661668 continue
662669
663- # Add it to the list and mark it modelled.
670+ # Add it to the list and mark it modelled by default .
664671 stattype = core .bayesdb_variable_stattype (
665672 bdb , population_id , colno )
666673 variables .append ([var , stattype , dist , params ])
667674 modelled .add (var )
675+ default_modelled .add (var )
668676
669677 elif isinstance (clause , cgpm_schema .parse .Latent ):
670- # Reject if the latent variable has already been declared.
671- if any (l [0 ] == clause .name for l in declared_latents ):
672- duplicate .add (clause .name )
678+ var = clause .name
679+ stattype = clause .stattype
673680
674- # XXX FILL ME XXX
681+ # Reject if the variable has already been modelled by the
682+ # default model.
683+ if var in default_modelled :
684+ duplicate .add (var )
685+ continue
675686
676- # Register the latent variable name and stattype into bayesdb.
677- # Do something related to the error checking data structures.
687+ # Reject if the variable even *exists* in the population
688+ # at all yet.
689+ if core .bayesdb_has_variable (bdb , population_id , var ):
690+ duplicate .add (var )
691+ continue
678692
679- declared_latents .append ((clause .name , clause .stattype ))
693+ # Add it to the set of latent variables.
694+ latents [var ] = stattype
680695
681696 elif isinstance (clause , cgpm_schema .parse .Foreign ):
682697 # Foreign model: some set of output variables is to be
@@ -699,9 +714,7 @@ def _retrieve_stattype_dist_params(var):
699714 # First make sure all the output variables exist and have
700715 # not yet been modelled.
701716 for var in clause .outputs :
702- if not core .bayesdb_has_variable (bdb , population_id , var ):
703- unknown .add (var )
704- continue
717+ must_exist .append (var )
705718 if var in modelled :
706719 duplicate .add (var )
707720 break
@@ -711,9 +724,7 @@ def _retrieve_stattype_dist_params(var):
711724 # them needed, and record where to put their
712725 # distribution type and parameters.
713726 for var in inputs :
714- if not core .bayesdb_has_variable (bdb , population_id , var ):
715- unknown .add (var )
716- continue
727+ must_exist .append (var )
717728 needed .add (var )
718729 # XXX check agreement with statistical type
719730 assert len (cctypes ) == len (ccargs )
@@ -737,6 +748,15 @@ def _retrieve_stattype_dist_params(var):
737748 else :
738749 raise BQLError (bdb , 'Unknown clause: %r' % (clause ,))
739750
751+ # Make sure all the outputs and inputs exist, either in the
752+ # population or as latents in this generator.
753+ for var in must_exist :
754+ if core .bayesdb_has_variable (bdb , population_id , var ):
755+ continue
756+ if var in latents :
757+ continue
758+ missing .add (var )
759+
740760 # Raise an exception if there were duplicates or unknown
741761 # variables.
742762 if duplicate :
@@ -775,6 +795,7 @@ def _retrieve_stattype_dist_params(var):
775795 'variables' : variables ,
776796 'cgpm_composition' : cgpm_composition ,
777797 'subsample' : subsample ,
798+ 'latents' : latents ,
778799 }
779800
780801def _default_categorical (bdb , generator_id , var ):
0 commit comments