Skip to content

Commit c3ce585

Browse files
Tie up various loose ends around latent variables.
- Make bayesdb_add_latent actually associate the generator with it. - Require caller to specify generator id or None in . core.bayesdb_has_variable . core.bayesdb_variable_number - Teach compiler to reject attempts to use latent variables when no generator is specified, or when they are not latent variables of the specified generator. - Fix various MODELLED BY mistakes in the compiler. - Fix bayesdb_variable table's generator_id column to cascade on delete so that DROP GENERATOR works again with latent variables.
1 parent 612c3d3 commit c3ce585

File tree

8 files changed

+206
-97
lines changed

8 files changed

+206
-97
lines changed

src/bql.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,15 @@ def execute_phrase(bdb, phrase, bindings=()):
148148
nsamples = cursor[0][0]
149149
assert isinstance(nsamples, int)
150150
constraints = [
151-
(core.bayesdb_variable_number(bdb, population_id, var), value)
151+
(core.bayesdb_variable_number(bdb, population_id, generator_id,
152+
var),
153+
value)
152154
for (var, _expression), value in
153155
zip(phrase.simulation.constraints, cursor[0][1:])
154156
]
155157
colnos = [
156-
core.bayesdb_variable_number(bdb, population_id, var)
158+
core.bayesdb_variable_number(bdb, population_id, generator_id,
159+
var)
157160
for var in column_names
158161
]
159162
bdb.sql_execute('CREATE %sTABLE %s%s (%s)' %

src/compiler.py

Lines changed: 90 additions & 40 deletions
Large diffs are not rendered by default.

src/core.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -208,20 +208,40 @@ def bayesdb_population_generators(bdb, population_id):
208208
''', (population_id,))
209209
return [generator_id for (generator_id,) in cursor]
210210

211-
def bayesdb_has_variable(bdb, population_id, name):
212-
"""True if the population has a given variable."""
213-
cursor = bdb.sql_execute('''
214-
SELECT COUNT(*) FROM bayesdb_variable
215-
WHERE population_id = ? AND name = ?
216-
''', (population_id, name))
211+
def bayesdb_has_variable(bdb, population_id, generator_id, name):
212+
"""True if the population has a given variable.
213+
214+
generator_id is None for manifest variables and the id of a
215+
generator for variables that may be latent.
216+
"""
217+
if generator_id is None:
218+
cursor = bdb.sql_execute('''
219+
SELECT COUNT(*) FROM bayesdb_variable
220+
WHERE population_id = ? AND generator_id IS NULL AND name = ?
221+
''', (population_id, name))
222+
else:
223+
cursor = bdb.sql_execute('''
224+
SELECT COUNT(*) FROM bayesdb_variable
225+
WHERE population_id = ?
226+
AND (generator_id IS NULL OR generator_id = ?)
227+
AND name = ?
228+
''', (population_id, generator_id, name))
217229
return cursor_value(cursor) != 0
218230

219-
def bayesdb_variable_number(bdb, population_id, name):
231+
def bayesdb_variable_number(bdb, population_id, generator_id, name):
220232
"""Return the column number of a population variable."""
221-
cursor = bdb.sql_execute('''
222-
SELECT colno FROM bayesdb_variable
223-
WHERE population_id = ? AND name = ?
224-
''', (population_id, name))
233+
if generator_id is None:
234+
cursor = bdb.sql_execute('''
235+
SELECT colno FROM bayesdb_variable
236+
WHERE population_id = ? AND generator_id IS NULL AND name = ?
237+
''', (population_id, name))
238+
else:
239+
cursor = bdb.sql_execute('''
240+
SELECT colno FROM bayesdb_variable
241+
WHERE population_id = ?
242+
AND (generator_id IS NULL OR generator_id = ?)
243+
AND name = ?
244+
''', (population_id, generator_id, name))
225245
return cursor_value(cursor)
226246

227247
def bayesdb_variable_names(bdb, population_id):
@@ -234,7 +254,7 @@ def bayesdb_variable_numbers(bdb, population_id):
234254
"""Return a list of the numbers of columns modelled in `population_id`."""
235255
sql = '''
236256
SELECT colno FROM bayesdb_variable
237-
WHERE population_id = ?
257+
WHERE population_id = ? AND generator_id IS NULL
238258
ORDER BY colno ASC
239259
'''
240260
return [row[0] for row in bdb.sql_execute(sql, (population_id,))]
@@ -291,9 +311,17 @@ def bayesdb_add_latent(bdb, population_id, generator_id, var, stattype):
291311
colno = min(-1, cursor_value(cursor) - 1)
292312
bdb.sql_execute('''
293313
INSERT INTO bayesdb_variable
294-
(population_id, colno, name, stattype)
295-
VALUES (?, ?, ?, ?)
296-
''', (population_id, colno, var, stattype))
314+
(population_id, generator_id, colno, name, stattype)
315+
VALUES (?, ?, ?, ?, ?)
316+
''', (population_id, generator_id, colno, var, stattype))
317+
318+
def bayesdb_has_latent(bdb, population_id, var):
319+
"""True if the population has a latent variable by the given name."""
320+
cursor = bdb.sql_execute('''
321+
SELECT COUNT(*) FROM bayesdb_variable
322+
WHERE population_id = ? AND name = ? AND generator_id IS NOT NULL
323+
''', (population_id, var))
324+
return cursor_value(cursor)
297325

298326
def bayesdb_population_cell_value(bdb, population_id, rowid, colno):
299327
if colno < 0:

src/metamodels/cgpm_metamodel.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ def retrieve_analyze_variables(ast):
260260
# Retrieve target variables.
261261
analyze_ast = cgpm_analyze.parse.parse(program)
262262
variables = retrieve_analyze_variables(analyze_ast)
263-
varnos = [core.bayesdb_variable_number(bdb, population_id, v)
263+
varnos = [core.bayesdb_variable_number(bdb, population_id,
264+
generator_id, v)
264265
for v in variables]
265266

266267
# Run transition.
@@ -399,7 +400,8 @@ def _unique_rowid(self, rowids):
399400
def _data(self, bdb, generator_id, vars):
400401
# Get the column numbers and statistical types.
401402
population_id = core.bayesdb_generator_population(bdb, generator_id)
402-
colnos = [core.bayesdb_variable_number(bdb, population_id, var)
403+
colnos = [core.bayesdb_variable_number(bdb, population_id,
404+
generator_id, var)
403405
for var in vars]
404406
stattypes = [core.bayesdb_variable_stattype(bdb, population_id, colno)
405407
for colno in colnos]
@@ -436,7 +438,8 @@ def map_value(colno, value):
436438
def _initialize_engine(self, bdb, generator_id, n, variables):
437439
population_id = core.bayesdb_generator_population(bdb, generator_id)
438440
def map_var(var):
439-
return core.bayesdb_variable_number(bdb, population_id, var)
441+
return core.bayesdb_variable_number(bdb, population_id,
442+
generator_id, var)
440443
outputs = [map_var(var) for var, _st, _cct, _da in variables]
441444
cctypes = [cctype for _n, _st, cctype, _da in variables]
442445
distargs = [distargs for _n, _st, _cct, distargs in variables]
@@ -449,7 +452,8 @@ def map_var(var):
449452
def _initialize_cgpm(self, bdb, generator_id, cgpm_ext):
450453
population_id = core.bayesdb_generator_population(bdb, generator_id)
451454
def map_var(var):
452-
return core.bayesdb_variable_number(bdb, population_id, var)
455+
return core.bayesdb_variable_number(bdb, population_id,
456+
generator_id, var)
453457
name = cgpm_ext['name']
454458
outputs = map(map_var, cgpm_ext['outputs'])
455459
inputs = map(map_var, cgpm_ext['inputs'])
@@ -632,11 +636,11 @@ def _create_schema(bdb, generator_id, schema_ast):
632636
duplicate = set()
633637
unknown = set()
634638
needed = set()
635-
invalid_latent = set()
639+
existing_latent = set()
636640
must_exist = []
637641

638642
def _retrieve_stattype_dist_params(var):
639-
colno = core.bayesdb_variable_number(bdb, population_id, var)
643+
colno = core.bayesdb_variable_number(bdb, population_id, None, var)
640644
stattype = core.bayesdb_variable_stattype(bdb, population_id, colno)
641645
dist, params = _DEFAULT_DIST[stattype](bdb, generator_id, var)
642646
return stattype, dist, params
@@ -652,7 +656,7 @@ def _retrieve_stattype_dist_params(var):
652656
params = dict(clause.params) # XXX error checking
653657

654658
# Reject if the variable does not exist.
655-
if not core.bayesdb_has_variable(bdb, population_id, var):
659+
if not core.bayesdb_has_variable(bdb, population_id, None, var):
656660
unknown.add(var)
657661
continue
658662

@@ -661,12 +665,15 @@ def _retrieve_stattype_dist_params(var):
661665
duplicate.add(var)
662666
continue
663667

664-
# Reject if it is a latent variable.
665-
colno = core.bayesdb_variable_number(bdb, population_id, var)
666-
if colno < 0:
667-
invalid_latent.add(var)
668+
# Reject if the variable is latent.
669+
if core.bayesdb_has_latent(bdb, population_id, var):
670+
existing_latent.add(var)
668671
continue
669672

673+
# Get the column number.
674+
colno = core.bayesdb_variable_number(bdb, population_id, None, var)
675+
assert not colno < 0
676+
670677
# Add it to the list and mark it modelled by default.
671678
stattype = core.bayesdb_variable_stattype(
672679
bdb, population_id, colno)
@@ -686,10 +693,16 @@ def _retrieve_stattype_dist_params(var):
686693

687694
# Reject if the variable even *exists* in the population
688695
# at all yet.
689-
if core.bayesdb_has_variable(bdb, population_id, var):
696+
if core.bayesdb_has_variable(bdb, population_id, None, var):
690697
duplicate.add(var)
691698
continue
692699

700+
# Reject if the variable is already latent, from another
701+
# generator.
702+
if core.bayesdb_has_latent(bdb, population_id, var):
703+
existing_latent.add(var)
704+
continue
705+
693706
# Add it to the set of latent variables.
694707
latents[var] = stattype
695708

@@ -751,7 +764,7 @@ def _retrieve_stattype_dist_params(var):
751764
# Make sure all the outputs and inputs exist, either in the
752765
# population or as latents in this generator.
753766
for var in must_exist:
754-
if core.bayesdb_has_variable(bdb, population_id, var):
767+
if core.bayesdb_has_variable(bdb, population_id, None, var):
755768
continue
756769
if var in latents:
757770
continue
@@ -765,17 +778,17 @@ def _retrieve_stattype_dist_params(var):
765778
if unknown:
766779
raise BQLError(bdb, 'Unknown model variables: %r' %
767780
(sorted(unknown),))
768-
if invalid_latent:
769-
raise BQLError(bdb, 'Invalid latent variables: %r' %
770-
(sorted(invalid_latent),))
781+
if existing_latent:
782+
raise BQLError(bdb, 'Latent variables already defined: %r' %
783+
(sorted(existing_latent),))
771784

772785
# Use the default distribution for any variables that remain to be
773786
# modelled, excluding any that are latent or that have statistical
774787
# types we don't know about.
775788
for var in core.bayesdb_variable_names(bdb, population_id):
776789
if var in modelled:
777790
continue
778-
colno = core.bayesdb_variable_number(bdb, population_id, var)
791+
colno = core.bayesdb_variable_number(bdb, population_id, None, var)
779792
if colno < 0:
780793
continue
781794
stattype, dist, params = _retrieve_stattype_dist_params(var)

src/schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@
131131
132132
CREATE TABLE bayesdb_variable (
133133
population_id INTEGER NOT NULL REFERENCES bayesdb_population(id),
134-
generator_id INTEGER REFERENCES bayesdb_generator(id),
134+
generator_id INTEGER REFERENCES bayesdb_generator(id)
135+
ON DELETE CASCADE,
135136
colno INTEGER NOT NULL,
136137
name TEXT COLLATE NOCASE NOT NULL,
137138
stattype TEXT COLLATE NOCASE NOT NULL

tests/test_bql.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,8 @@ def test_estimate_columns_trivial():
454454
prefix1 = ' FROM bayesdb_population AS p,' \
455455
' bayesdb_variable AS v, bayesdb_column AS c' \
456456
' WHERE p.id = 1 AND v.population_id = p.id' \
457-
' AND c.tabname = p.tabname AND c.colno = v.colno'
457+
' AND c.tabname = p.tabname AND c.colno = v.colno' \
458+
' AND v.generator_id IS NULL'
458459
prefix = prefix0 + prefix1
459460
assert bql2sql('estimate * from columns of p1;') == \
460461
prefix + ';'
@@ -557,6 +558,8 @@ def test_estimate_pairwise_trivial():
557558
infix0 += ' AND v0.population_id = p.id AND v1.population_id = p.id'
558559
infix0 += ' AND c0.tabname = p.tabname AND c0.colno = v0.colno'
559560
infix0 += ' AND c1.tabname = p.tabname AND c1.colno = v1.colno'
561+
infix0 += ' AND v0.generator_id IS NULL'
562+
infix0 += ' AND v1.generator_id IS NULL'
560563
infix += infix0
561564
assert bql2sql('estimate dependence probability'
562565
' from pairwise columns of p1;') == \
@@ -695,6 +698,7 @@ def test_estimate_pairwise_selected_columns():
695698
' AND v0.population_id = p.id AND v1.population_id = p.id' \
696699
' AND c0.tabname = p.tabname AND c0.colno = v0.colno' \
697700
' AND c1.tabname = p.tabname AND c1.colno = v1.colno' \
701+
' AND v0.generator_id IS NULL AND v1.generator_id IS NULL' \
698702
' AND c0.colno IN (1, 2) AND c1.colno IN (1, 2);'
699703
assert bql2sql('estimate dependence probability'
700704
' from pairwise columns of p1'
@@ -710,6 +714,7 @@ def test_estimate_pairwise_selected_columns():
710714
' AND v0.population_id = p.id AND v1.population_id = p.id' \
711715
' AND c0.tabname = p.tabname AND c0.colno = v0.colno' \
712716
' AND c1.tabname = p.tabname AND c1.colno = v1.colno' \
717+
' AND v0.generator_id IS NULL AND v1.generator_id IS NULL' \
713718
' AND c0.colno IN (3, 1) AND c1.colno IN (3, 1);'
714719

715720
def test_select_columns_subquery():
@@ -980,9 +985,11 @@ def trace(string, _bindings):
980985
' bayesdb_column AS c'
981986
' WHERE p.id = 1 AND v.population_id = p.id'
982987
' AND c.tabname = p.tabname AND c.colno = v.colno'
988+
' AND v.generator_id IS NULL'
983989
' LIMIT 1',
984990
'SELECT colno FROM bayesdb_variable'
985-
' WHERE population_id = ? AND name = ?',
991+
' WHERE population_id = ? AND generator_id IS NULL'
992+
' AND name = ?',
986993
# ESTIMATE SIMILARITY TO (rowid=1):
987994
'SELECT tabname FROM bayesdb_population WHERE id = ?',
988995
'SELECT bql_row_similarity(1, NULL, _rowid_,'
@@ -1022,9 +1029,11 @@ def trace(string, _bindings):
10221029
' bayesdb_column AS c'
10231030
' WHERE p.id = 1 AND v.population_id = p.id'
10241031
' AND c.tabname = p.tabname AND c.colno = v.colno'
1032+
' AND v.generator_id IS NULL'
10251033
' LIMIT ?1',
10261034
'SELECT colno FROM bayesdb_variable'
1027-
' WHERE population_id = ? AND name = ?',
1035+
' WHERE population_id = ? AND generator_id IS NULL'
1036+
' AND name = ?',
10281037
'SELECT tabname FROM bayesdb_population WHERE id = ?',
10291038
# ESTIMATE SIMILARITY TO (rowid=1):
10301039
'SELECT bql_row_similarity(1, NULL, _rowid_,'
@@ -1054,13 +1063,17 @@ def trace(string, _bindings):
10541063
'PRAGMA table_info("t")',
10551064
"SELECT CAST(4 AS INTEGER), 'F'",
10561065
'SELECT colno FROM bayesdb_variable'
1057-
' WHERE population_id = ? AND name = ?',
1066+
' WHERE population_id = ? AND generator_id IS NULL'
1067+
' AND name = ?',
10581068
'SELECT colno FROM bayesdb_variable'
1059-
' WHERE population_id = ? AND name = ?',
1069+
' WHERE population_id = ? AND generator_id IS NULL'
1070+
' AND name = ?',
10601071
'SELECT colno FROM bayesdb_variable'
1061-
' WHERE population_id = ? AND name = ?',
1072+
' WHERE population_id = ? AND generator_id IS NULL'
1073+
' AND name = ?',
10621074
'SELECT colno FROM bayesdb_variable'
1063-
' WHERE population_id = ? AND name = ?',
1075+
' WHERE population_id = ? AND generator_id IS NULL'
1076+
' AND name = ?',
10641077
'CREATE TEMP TABLE IF NOT EXISTS "sim"'
10651078
' ("age" NUMERIC,"RANK" NUMERIC,"division" NUMERIC)',
10661079
'SELECT tabname FROM bayesdb_population WHERE id = ?',
@@ -1194,13 +1207,13 @@ def trace(string, _bindings):
11941207
'SELECT id FROM bayesdb_population WHERE name = ?',
11951208
'SELECT tabname FROM bayesdb_population WHERE id = ?',
11961209
'PRAGMA table_info("t")',
1197-
'SELECT name, stattype FROM bayesdb_variable'
1198-
' WHERE population_id = ? AND colno < 0',
11991210
"SELECT CAST(4 AS INTEGER), 'F'",
12001211
'SELECT colno FROM bayesdb_variable'
1201-
' WHERE population_id = ? AND name = ?',
1212+
' WHERE population_id = ? AND generator_id IS NULL'
1213+
' AND name = ?',
12021214
'SELECT colno FROM bayesdb_variable'
1203-
' WHERE population_id = ? AND name = ?',
1215+
' WHERE population_id = ? AND generator_id IS NULL'
1216+
' AND name = ?',
12041217
'SELECT tabname FROM bayesdb_population WHERE id = ?',
12051218
'SELECT MAX(_rowid_) FROM "t"',
12061219
'SELECT id FROM bayesdb_generator WHERE population_id = ?',
@@ -1440,7 +1453,7 @@ def test_createtab():
14401453
overrides=[('age', 'ignore')])
14411454
bdb.execute('drop population p0')
14421455
population_id = core.bayesdb_get_population(bdb, 'p')
1443-
colno = core.bayesdb_variable_number(bdb, population_id, 'age')
1456+
colno = core.bayesdb_variable_number(bdb, population_id, None, 'age')
14441457
assert core.bayesdb_variable_stattype(bdb, population_id, colno) == \
14451458
'numerical'
14461459
bdb.execute('initialize 1 model for p_cc')

tests/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,9 @@ def impute_and_confidence(self, seed, M_c, X_L, X_D, Y, Q, n):
568568
)
569569
''')
570570
pid = core.bayesdb_get_population(bdb, 'p1')
571-
assert core.bayesdb_variable_number(bdb, pid, 'label') == 1
572-
assert core.bayesdb_variable_number(bdb, pid, 'age') == 2
573-
assert core.bayesdb_variable_number(bdb, pid, 'weight') == 3
571+
assert core.bayesdb_variable_number(bdb, pid, None, 'label') == 1
572+
assert core.bayesdb_variable_number(bdb, pid, None, 'age') == 2
573+
assert core.bayesdb_variable_number(bdb, pid, None, 'weight') == 3
574574
gid = core.bayesdb_get_generator(bdb, 'p1_cc')
575575
from bayeslite.metamodels.crosscat import crosscat_cc_colno
576576
assert crosscat_cc_colno(bdb, gid, 1) == 0

0 commit comments

Comments
 (0)