Skip to content

Commit 4e0a01a

Browse files
Add generator_id parameter to bayesdb_variable_numbers.
Require callers to specify whether they want to include the latent variables of a specific generator or not.
1 parent 9469465 commit 4e0a01a

File tree

4 files changed

+27
-17
lines changed

4 files changed

+27
-17
lines changed

src/bqlfn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,8 @@ def bql_row_similarity(bdb, population_id, generator_id, rowid, target_rowid,
356356
if target_rowid is None:
357357
raise BQLError(bdb, 'No such target row for SIMILARITY')
358358
if len(colnos) == 0:
359-
colnos = core.bayesdb_variable_numbers(bdb, population_id)
359+
colnos = core.bayesdb_variable_numbers(bdb, population_id,
360+
generator_id)
360361
def generator_similarity(generator_id):
361362
metamodel = core.bayesdb_generator_metamodel(bdb, generator_id)
362363
return metamodel.row_similarity(bdb, generator_id, None, rowid,

src/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1194,7 +1194,8 @@ def compile_column_lists(bdb, population_id, generator_id, column_lists,
11941194
else:
11951195
out.write(', ')
11961196
if isinstance(collist, ast.ColListAll):
1197-
colnos = core.bayesdb_variable_numbers(bdb, population_id)
1197+
colnos = core.bayesdb_variable_numbers(bdb, population_id,
1198+
generator_id)
11981199
out.write(', '.join(str(colno) for colno in colnos))
11991200
elif isinstance(collist, ast.ColListLit):
12001201
unknown = set()

src/core.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,20 +244,27 @@ def bayesdb_variable_number(bdb, population_id, generator_id, name):
244244
''', (population_id, generator_id, name))
245245
return cursor_value(cursor)
246246

247-
def bayesdb_variable_names(bdb, population_id):
247+
def bayesdb_variable_names(bdb, population_id, generator_id):
248248
"""Return a list of the names of columns modelled in `population_id`."""
249-
colnos = bayesdb_variable_numbers(bdb, population_id)
249+
colnos = bayesdb_variable_numbers(bdb, population_id, generator_id)
250250
return [bayesdb_variable_name(bdb, population_id, colno)
251251
for colno in colnos]
252252

253-
def bayesdb_variable_numbers(bdb, population_id):
253+
def bayesdb_variable_numbers(bdb, population_id, generator_id):
254254
"""Return a list of the numbers of columns modelled in `population_id`."""
255-
sql = '''
256-
SELECT colno FROM bayesdb_variable
257-
WHERE population_id = ? AND generator_id IS NULL
258-
ORDER BY colno ASC
259-
'''
260-
return [row[0] for row in bdb.sql_execute(sql, (population_id,))]
255+
if generator_id is None:
256+
cursor = bdb.sql_execute('''
257+
SELECT colno FROM bayesdb_variable
258+
WHERE population_id = ? AND generator_id IS NULL
259+
ORDER BY colno ASC
260+
''', (population_id,))
261+
else:
262+
cursor = bdb.sql_execute('''
263+
SELECT colno FROM bayesdb_variable
264+
WHERE population_id = ? AND generator_id = ?
265+
ORDER BY colno ASC
266+
''', (population_id, generator_id))
267+
return [colno for (colno,) in cursor]
261268

262269
def bayesdb_variable_name(bdb, population_id, colno):
263270
"""Return the name a population variable."""

src/metamodels/cgpm_metamodel.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ def analyze_models(self, bdb, generator_id, modelnos=None, iterations=1,
237237
def retrieve_analyze_variables(ast):
238238
# Transition all variables by default.
239239
if len(ast) == 0:
240-
variables = core.bayesdb_variable_names(bdb, population_id)
240+
variables = core.bayesdb_variable_names(bdb, population_id,
241+
generator_id)
241242
# Exactly 1 clause supported.
242243
elif len(ast) == 1:
243244
clause = ast[0]
@@ -248,7 +249,8 @@ def retrieve_analyze_variables(ast):
248249
elif isinstance(clause, cgpm_analyze.parse.Skip):
249250
variables = filter(
250251
lambda v: v not in clause.vars,
251-
core.bayesdb_variable_names(bdb, population_id))
252+
core.bayesdb_variable_names(bdb, population_id,
253+
generator_id))
252254
# Unknown/impossible clause.
253255
else:
254256
raise ValueError('Unknown clause in ANALYZE: %s.' % ast)
@@ -672,7 +674,7 @@ def _retrieve_stattype_dist_params(var):
672674

673675
# Get the column number.
674676
colno = core.bayesdb_variable_number(bdb, population_id, None, var)
675-
assert not colno < 0
677+
assert 0 <= colno
676678

677679
# Add it to the list and mark it modelled by default.
678680
stattype = core.bayesdb_variable_stattype(
@@ -785,12 +787,11 @@ def _retrieve_stattype_dist_params(var):
785787
# Use the default distribution for any variables that remain to be
786788
# modelled, excluding any that are latent or that have statistical
787789
# types we don't know about.
788-
for var in core.bayesdb_variable_names(bdb, population_id):
790+
for var in core.bayesdb_variable_names(bdb, population_id, None):
789791
if var in modelled:
790792
continue
791793
colno = core.bayesdb_variable_number(bdb, population_id, None, var)
792-
if colno < 0:
793-
continue
794+
assert 0 <= colno
794795
stattype, dist, params = _retrieve_stattype_dist_params(var)
795796
if stattype not in _DEFAULT_DIST:
796797
assert False # XXX Why would you be here, anyway?

0 commit comments

Comments
 (0)