Skip to content

Commit b78bb93

Browse files
committed
Add expectation_method and urn_model options for omega expectations
1 parent 0c05004 commit b78bb93

File tree

7 files changed

+336
-44
lines changed

7 files changed

+336
-44
lines changed

csubst/csubst

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -508,18 +508,28 @@ def _main():
508508
analyze.add_argument('--cbs', metavar='yes|no', default='no', type=strtobool,
509509
help='default=%(default)s: Combinatorial-branch-site output. Set "yes" to generate the output tsv.')
510510
# Omega_C calculation
511-
analyze.add_argument('--omegaC_method', metavar='[submodel|modelfree]', default='submodel', type=str,
512-
choices=['modelfree','submodel'],
513-
help='default=%(default)s: Method to calculate omega_C. '
514-
'"submodel" utilizes a codon substitution model in the ancestral state reconstruction. '
515-
'In addition to the base substitution models, codon frequencies and '
516-
'among-site rate heterogeneity are taken into account. '
511+
analyze.add_argument('--expectation_method', metavar='codon_model|urn', default=None, type=str,
512+
choices=['codon_model', 'urn'],
513+
help='default=codon_model: Method to calculate omega_C expected values. '
514+
'"codon_model" utilizes a codon substitution model in ancestral-state reconstruction. '
515+
'In addition to base substitution models, codon frequencies and among-site rate '
516+
'heterogeneity are taken into account. '
517517
'Described in Fukushima and Pollock (2023, https://doi.org/10.1038/s41559-022-01932-7). '
518-
'"modelfree" (experimental) for expected values from among-site randomization '
519-
'(urn sampling) of substitutions. ')
518+
'"urn" uses among-site randomization (weighted urn sampling).')
519+
analyze.add_argument('--urn_model', metavar='wallenius|fisher|factorized_approx', default=None, type=str,
520+
choices=['wallenius', 'fisher', 'factorized_approx'],
521+
help='default=wallenius: Urn expectation model used when --expectation_method urn. '
522+
'"wallenius" uses Wallenius-type inclusion probabilities; '
523+
'"fisher" uses Fisher-type inclusion probabilities; '
524+
'"factorized_approx" uses the legacy factorized approximation.')
525+
analyze.add_argument('--omegaC_method', metavar='[submodel|modelfree]', default=None, type=str,
526+
choices=['modelfree','submodel'],
527+
help='Deprecated alias for --expectation_method. '
528+
'"submodel" maps to --expectation_method codon_model. '
529+
'"modelfree" maps to --expectation_method urn.')
520530
analyze.add_argument('--calc_omega_pvalue', metavar='yes|no', default='no', type=strtobool,
521531
help='default=%(default)s: Experimental feature. Estimate branch-combination-wise one-sided empirical '
522-
'P values of omega_C by substitution randomization (modelfree only). '
532+
'P values of omega_C by substitution randomization (--expectation_method urn only). '
523533
'When --calibrate_longtail is active, p-value columns are suffixed with "_nocalib".')
524534
analyze.add_argument('--omega_pvalue_null_model', metavar='hypergeom|poisson|poisson_full|nbinom', default='hypergeom', type=str,
525535
choices=['hypergeom', 'poisson', 'poisson_full', 'nbinom'],
@@ -553,7 +563,7 @@ def _main():
553563
analyze.add_argument('--asrv', metavar='no|pool|sn|each|file|file_each', default='each', type=str,
554564
choices=['no', 'pool', 'sn', 'each', 'file', 'file_each'],
555565
help='default=%(default)s: Experimental. Correct among-site rate variation in omega calculation. '
556-
'This option is used in --omegaC_method modelfree but not with --omegaC_method submodel. '
566+
'This option is used with --expectation_method urn but not with --expectation_method codon_model. '
557567
'"no", No ASRV, meaning a uniform rate among sites. '
558568
'"pool", All categories of substitutions are pooled to calculate a single set of ASRV. '
559569
'"sn", Synonymous and nonsynonymous substitutions are processed individually '

csubst/main_analyze.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@
2020
from csubst import tree
2121

2222

23+
def _resolve_expectation_method(g):
24+
token = g.get('expectation_method', None)
25+
if (token is None) or (str(token).strip() == ''):
26+
token = g.get('omegaC_method', 'submodel')
27+
normalized = str(token).strip().lower()
28+
if normalized in ['codon_model', 'submodel', '']:
29+
return 'codon_model'
30+
if normalized in ['urn', 'modelfree']:
31+
return 'urn'
32+
raise ValueError('Unsupported expectation method: {}'.format(token))
33+
34+
2335
def _remap_site_column_to_alignment(df, g, column_name='site'):
2436
if column_name not in df.columns:
2537
return df
@@ -541,8 +553,8 @@ def _prepare_epistasis_configuration(g, ON_tensor, OS_tensor):
541553
if not bool(g.get('epistasis_requested', False)):
542554
g['epistasis_enabled'] = False
543555
return g
544-
if str(g.get('omegaC_method', '')).strip().lower() != 'modelfree':
545-
raise ValueError('--epistasis_beta should be used with --omegaC_method "modelfree".')
556+
if _resolve_expectation_method(g) != 'urn':
557+
raise ValueError('--epistasis_beta should be used with --expectation_method "urn".')
546558
num_site = int(ON_tensor.shape[1])
547559
degree_internal = _load_epistasis_degree_from_file(g=g, num_site=num_site)
548560
if degree_internal is None:
@@ -659,7 +671,7 @@ def main_analyze(g):
659671
elapsed_time = int(time.time() - start)
660672
print(("Elapsed time: {:,.1f} sec\n".format(elapsed_time)), flush=True)
661673

662-
if (g['omegaC_method']!='submodel'):
674+
if _resolve_expectation_method(g) != 'codon_model':
663675
g['state_cdn'] = None
664676
g['state_pep'] = None
665677
g['state_nsy'] = None

csubst/omega.py

Lines changed: 165 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,32 @@ def _resolve_omega_pvalue_rounding_mode(g):
12041204
return mode
12051205

12061206

1207+
def _resolve_expectation_method(g):
1208+
token = 'codon_model'
1209+
if g is not None:
1210+
token = g.get('expectation_method', None)
1211+
if (token is None) or (str(token).strip() == ''):
1212+
token = g.get('omegaC_method', 'submodel')
1213+
normalized = str(token).strip().lower()
1214+
if normalized in ['codon_model', 'submodel', '']:
1215+
return 'codon_model'
1216+
if normalized in ['urn', 'modelfree']:
1217+
return 'urn'
1218+
raise ValueError('Unsupported expectation_method: {}'.format(token))
1219+
1220+
1221+
def _resolve_urn_model(g):
1222+
token = 'wallenius'
1223+
if g is not None:
1224+
token = g.get('urn_model', 'wallenius')
1225+
normalized = str(token).strip().lower()
1226+
if normalized in ['wallenius', 'fisher']:
1227+
return normalized
1228+
if normalized in ['factorized_approx', 'factorized', 'legacy_factorized', 'approx']:
1229+
return 'factorized_approx'
1230+
raise ValueError('urn_model should be one of wallenius, fisher, factorized_approx.')
1231+
1232+
12071233
def _prepare_permutation_branch_sizes(sub_branches, niter, g):
12081234
sub_branches = np.asarray(sub_branches)
12091235
if sub_branches.ndim != 1:
@@ -1283,7 +1309,71 @@ def _calc_wallenius_inclusion_probabilities(site_weights, draw_size, float_type=
12831309
return out.astype(float_type, copy=False)
12841310

12851311

1286-
def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_type):
1312+
def _calc_fisher_inclusion_probabilities(site_weights, draw_size, float_type=np.float64):
1313+
site_weights = np.asarray(site_weights, dtype=np.float64).reshape(-1)
1314+
if site_weights.ndim != 1:
1315+
raise ValueError('site_weights should be a 1D array.')
1316+
if (site_weights < 0).any():
1317+
raise ValueError('site_weights should be non-negative.')
1318+
draw_size = int(draw_size)
1319+
if draw_size < 0:
1320+
raise ValueError('draw_size should be >= 0.')
1321+
out = np.zeros(shape=site_weights.shape, dtype=np.float64)
1322+
positive_mask = (site_weights > 0)
1323+
positive_weights = site_weights[positive_mask]
1324+
num_positive = int(positive_weights.shape[0])
1325+
if (draw_size == 0) or (num_positive == 0):
1326+
return out.astype(float_type, copy=False)
1327+
if draw_size >= num_positive:
1328+
out[positive_mask] = 1.0
1329+
return out.astype(float_type, copy=False)
1330+
1331+
target = float(draw_size)
1332+
lo = 0.0
1333+
hi = 1.0
1334+
for _ in range(128):
1335+
scaled = hi * positive_weights
1336+
current = (1.0 - (1.0 / (1.0 + scaled))).sum(dtype=np.float64)
1337+
if current >= target:
1338+
break
1339+
hi *= 2.0
1340+
for _ in range(96):
1341+
mid = (lo + hi) / 2.0
1342+
scaled = mid * positive_weights
1343+
current = (1.0 - (1.0 / (1.0 + scaled))).sum(dtype=np.float64)
1344+
if current < target:
1345+
lo = mid
1346+
else:
1347+
hi = mid
1348+
lam = (lo + hi) / 2.0
1349+
scaled = lam * positive_weights
1350+
out_positive = 1.0 - (1.0 / (1.0 + scaled))
1351+
out[positive_mask] = np.clip(out_positive, a_min=0.0, a_max=1.0)
1352+
return out.astype(float_type, copy=False)
1353+
1354+
1355+
def _normalize_sub_sites_for_urn_overlap(sub_sites, num_branch):
1356+
sub_sites = np.asarray(sub_sites, dtype=np.float64)
1357+
if sub_sites.ndim == 1:
1358+
sub_sites = np.broadcast_to(sub_sites.reshape(1, -1), (int(num_branch), sub_sites.shape[0]))
1359+
elif sub_sites.ndim != 2:
1360+
raise ValueError('sub_sites should be a 1D or 2D array.')
1361+
if sub_sites.shape[0] != int(num_branch):
1362+
txt = 'sub_sites branch axis ({}) and sub_branches length ({}) should match.'
1363+
raise ValueError(txt.format(sub_sites.shape[0], int(num_branch)))
1364+
if (sub_sites < 0).any():
1365+
raise ValueError('sub_sites should be non-negative.')
1366+
return sub_sites
1367+
1368+
1369+
def _calc_weighted_urn_expected_overlap(
1370+
cb_ids,
1371+
sub_sites,
1372+
sub_branches,
1373+
g,
1374+
float_type,
1375+
inclusion_probability_func,
1376+
):
12871377
cb_ids = np.asarray(cb_ids, dtype=np.int64)
12881378
if cb_ids.ndim != 2:
12891379
raise ValueError('cb_ids should be a 2D array.')
@@ -1298,17 +1388,7 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
12981388
if not np.isfinite(sub_branches).all():
12991389
raise ValueError('sub_branches should be finite.')
13001390
np.clip(sub_branches, a_min=0.0, a_max=None, out=sub_branches)
1301-
1302-
sub_sites = np.asarray(sub_sites, dtype=np.float64)
1303-
if sub_sites.ndim == 1:
1304-
sub_sites = np.broadcast_to(sub_sites.reshape(1, -1), (sub_branches.shape[0], sub_sites.shape[0]))
1305-
elif sub_sites.ndim != 2:
1306-
raise ValueError('sub_sites should be a 1D or 2D array.')
1307-
if sub_sites.shape[0] != sub_branches.shape[0]:
1308-
txt = 'sub_sites branch axis ({}) and sub_branches length ({}) should match.'
1309-
raise ValueError(txt.format(sub_sites.shape[0], sub_branches.shape[0]))
1310-
if (sub_sites < 0).any():
1311-
raise ValueError('sub_sites should be non-negative.')
1391+
sub_sites = _normalize_sub_sites_for_urn_overlap(sub_sites=sub_sites, num_branch=sub_branches.shape[0])
13121392
if sub_sites.shape[0] <= cb_ids.max():
13131393
raise ValueError('cb_ids contain out-of-range branch IDs.')
13141394

@@ -1324,15 +1404,15 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13241404
continue
13251405
size_lo = int(np.clip(base[branch_id], a_min=0, a_max=num_positive))
13261406
size_hi = int(np.clip(base[branch_id] + 1, a_min=0, a_max=num_positive))
1327-
prob_lo = _calc_wallenius_inclusion_probabilities(
1407+
prob_lo = inclusion_probability_func(
13281408
site_weights=branch_weights,
13291409
draw_size=size_lo,
13301410
float_type=np.float64,
13311411
)
13321412
if (frac[branch_id] <= 0) or (size_lo == size_hi):
13331413
inclusion_prob[branch_id, :] = prob_lo
13341414
else:
1335-
prob_hi = _calc_wallenius_inclusion_probabilities(
1415+
prob_hi = inclusion_probability_func(
13361416
site_weights=branch_weights,
13371417
draw_size=size_hi,
13381418
float_type=np.float64,
@@ -1354,7 +1434,7 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13541434
if num_positive == 0:
13551435
continue
13561436
draw_size = int(np.clip(rounded_sizes[branch_id], a_min=0, a_max=num_positive))
1357-
inclusion_prob[branch_id, :] = _calc_wallenius_inclusion_probabilities(
1437+
inclusion_prob[branch_id, :] = inclusion_probability_func(
13581438
site_weights=branch_weights,
13591439
draw_size=draw_size,
13601440
float_type=np.float64,
@@ -1366,6 +1446,63 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13661446
)
13671447

13681448

1449+
def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_type):
    """Delegate to the weighted-urn overlap routine with Wallenius inclusion probabilities.

    All arguments are forwarded unchanged to
    ``_calc_weighted_urn_expected_overlap``; only the inclusion-probability
    function is fixed to ``_calc_wallenius_inclusion_probabilities``.
    """
    forwarded = dict(
        cb_ids=cb_ids,
        sub_sites=sub_sites,
        sub_branches=sub_branches,
        g=g,
        float_type=float_type,
    )
    return _calc_weighted_urn_expected_overlap(
        inclusion_probability_func=_calc_wallenius_inclusion_probabilities,
        **forwarded
    )
1458+
1459+
1460+
def _calc_fisher_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_type):
    """Delegate to the weighted-urn overlap routine with Fisher inclusion probabilities.

    All arguments are forwarded unchanged to
    ``_calc_weighted_urn_expected_overlap``; only the inclusion-probability
    function is fixed to ``_calc_fisher_inclusion_probabilities``.
    """
    forwarded = dict(
        cb_ids=cb_ids,
        sub_sites=sub_sites,
        sub_branches=sub_branches,
        g=g,
        float_type=float_type,
    )
    return _calc_weighted_urn_expected_overlap(
        inclusion_probability_func=_calc_fisher_inclusion_probabilities,
        **forwarded
    )
1469+
1470+
1471+
def _calc_urn_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_type):
    """Dispatch the expected-overlap calculation to the configured urn model.

    The model is chosen by ``_resolve_urn_model(g)``:

    * 'wallenius' -> ``_calc_wallenius_expected_overlap``
    * 'fisher' -> ``_calc_fisher_expected_overlap``
    * 'factorized_approx' -> ``_calc_tmp_E_sum`` on sanitized inputs
      (``sub_branches`` flattened, checked for finiteness, and clipped at
      zero; ``sub_sites`` normalized to one row per branch).

    Args:
        cb_ids: Branch-combination IDs, forwarded unchanged.
        sub_sites: Per-site values, forwarded (normalized for the factorized model).
        sub_branches: Per-branch values, forwarded (sanitized for the factorized model).
        g: Run-time options mapping used to select the urn model.
        float_type: Output dtype, forwarded unchanged.

    Returns:
        Whatever the selected implementation returns.

    Raises:
        ValueError: If ``sub_branches`` contains non-finite values (factorized
            model) or the resolved urn model is not supported.
    """
    urn_model = _resolve_urn_model(g=g)
    if urn_model == 'wallenius':
        return _calc_wallenius_expected_overlap(
            cb_ids=cb_ids,
            sub_sites=sub_sites,
            sub_branches=sub_branches,
            g=g,
            float_type=float_type,
        )
    if urn_model == 'fisher':
        return _calc_fisher_expected_overlap(
            cb_ids=cb_ids,
            sub_sites=sub_sites,
            sub_branches=sub_branches,
            g=g,
            float_type=float_type,
        )
    if urn_model == 'factorized_approx':
        # reshape(-1) always yields a 1D array, so no separate ndim check is needed.
        sub_branches = np.asarray(sub_branches, dtype=np.float64).reshape(-1)
        if not np.isfinite(sub_branches).all():
            raise ValueError('sub_branches should be finite.')
        # Clip into a NEW array: asarray/reshape may return a view of the
        # caller's buffer, and an in-place clip would silently modify it.
        sub_branches = np.clip(sub_branches, a_min=0.0, a_max=None)
        sub_sites = _normalize_sub_sites_for_urn_overlap(sub_sites=sub_sites, num_branch=sub_branches.shape[0])
        return _calc_tmp_E_sum(
            cb_ids=cb_ids,
            sub_sites=sub_sites,
            sub_branches=sub_branches,
            float_type=float_type,
        )
    raise ValueError('Unsupported urn_model: {}'.format(urn_model))
1504+
1505+
13691506
def _fill_packed_masks_for_sizes(packed_masks_branch, site_p, size_values):
13701507
size_values = np.asarray(size_values, dtype=np.int64).reshape(-1)
13711508
if packed_masks_branch.shape[0] != size_values.shape[0]:
@@ -1742,7 +1879,7 @@ def calc_E_mean(mode, cb_ids, sub_sg, sub_bg, obs_col, list_igad, g, static_sub_
17421879
else:
17431880
sub_sites = static_sub_sites
17441881
sub_branches = substitution.get_sub_branches(sub_bg, mode, sg, a, d)
1745-
E_b += _calc_wallenius_expected_overlap(
1882+
E_b += _calc_urn_expected_overlap(
17461883
cb_ids=cb_ids,
17471884
sub_sites=sub_sites,
17481885
sub_branches=sub_branches,
@@ -1777,7 +1914,7 @@ def joblib_calc_E_mean(
17771914
else:
17781915
sub_sites = static_sub_sites
17791916
sub_branches = substitution.get_sub_branches(sub_bg, mode, sg, a, d)
1780-
dfEb += _calc_wallenius_expected_overlap(
1917+
dfEb += _calc_urn_expected_overlap(
17811918
cb_ids=cb_ids,
17821919
sub_sites=sub_sites,
17831920
sub_branches=sub_branches,
@@ -2015,7 +2152,8 @@ def subroot_E2nan(cb, tree):
20152152
def get_E(cb, g, ON_tensor, OS_tensor):
20162153
requested_output_stats = _resolve_requested_output_stats(g)
20172154
base_stats = output_stat.get_required_base_stats(requested_output_stats)
2018-
if (g['omegaC_method']=='modelfree'):
2155+
expectation_method = _resolve_expectation_method(g=g)
2156+
if expectation_method == 'urn':
20192157
ON_gad, ON_ga, ON_gd = substitution.get_group_state_totals(ON_tensor)
20202158
OS_gad, OS_ga, OS_gd = substitution.get_group_state_totals(OS_tensor)
20212159
g['N_ind_nomissing_gad'] = np.where(ON_gad!=0)
@@ -2027,7 +2165,7 @@ def get_E(cb, g, ON_tensor, OS_tensor):
20272165
for st in base_stats:
20282166
cb['ECN'+st] = calc_E_stat(cb, ON_tensor, mode=st, stat='mean', SN='N', g=g)
20292167
cb['ECS'+st] = calc_E_stat(cb, OS_tensor, mode=st, stat='mean', SN='S', g=g)
2030-
if (g['omegaC_method']=='submodel'):
2168+
elif expectation_method == 'codon_model':
20312169
id_cols = cb.columns[cb.columns.str.startswith('branch_id_')]
20322170
state_nsyE = get_exp_state(g=g, mode='nsy')
20332171
if (g['current_arity']==2):
@@ -2059,6 +2197,8 @@ def get_E(cb, g, ON_tensor, OS_tensor):
20592197
)
20602198
cb = table.merge_tables(cb, cbES)
20612199
del state_cdnE,cbES
2200+
else:
2201+
raise ValueError('Unsupported expectation_method: {}'.format(expectation_method))
20622202
cb = substitution.add_dif_stats(cb, g['float_tol'], prefix='EC', output_stats=requested_output_stats)
20632203
cb = subroot_E2nan(cb, tree=g['tree'])
20642204
return cb
@@ -2598,7 +2738,7 @@ def _calc_poisson_count_matrix(
25982738
else:
25992739
sub_sites = static_sub_sites
26002740
sub_branches = substitution.get_sub_branches(sub_bg, mode, sg, a, d)
2601-
mean_count = _calc_wallenius_expected_overlap(
2741+
mean_count = _calc_urn_expected_overlap(
26022742
cb_ids=cb_ids,
26032743
sub_sites=sub_sites,
26042744
sub_branches=sub_branches,
@@ -2717,7 +2857,7 @@ def _calc_nbinom_count_matrix(
27172857
else:
27182858
sub_sites = static_sub_sites
27192859
sub_branches = substitution.get_sub_branches(sub_bg, mode, sg, a, d)
2720-
mean_count = _calc_wallenius_expected_overlap(
2860+
mean_count = _calc_urn_expected_overlap(
27212861
cb_ids=cb_ids,
27222862
sub_sites=sub_sites,
27232863
sub_branches=sub_branches,
@@ -2817,7 +2957,7 @@ def _calc_poisson_full_count_matrix(
28172957
sub_site_mass[nonzero_branch, :] /
28182958
sub_branches[nonzero_branch, None]
28192959
)
2820-
mean_count = _calc_wallenius_expected_overlap(
2960+
mean_count = _calc_urn_expected_overlap(
28212961
cb_ids=cb_ids,
28222962
sub_sites=sub_sites,
28232963
sub_branches=sub_branches,
@@ -3141,8 +3281,9 @@ def _resolve_omega_pvalue_dsc_calibration_transformation(cb, sub, g):
31413281
def add_omega_empirical_pvalues(cb, ON_tensor, OS_tensor, g):
31423282
if not bool(g.get('calc_omega_pvalue', False)):
31433283
return cb
3144-
if str(g.get('omegaC_method', '')).strip().lower() != 'modelfree':
3145-
sys.stderr.write('Skipping --calc_omega_pvalue because --omegaC_method is not "modelfree".\n')
3284+
if _resolve_expectation_method(g=g) != 'urn':
3285+
txt = 'Skipping --calc_omega_pvalue because --expectation_method is not "urn".\n'
3286+
sys.stderr.write(txt)
31463287
return cb
31473288
null_model = _resolve_omega_pvalue_null_model(g=g)
31483289
txt = 'omega_C empirical p-value null model: {}'

0 commit comments

Comments
 (0)