@@ -1204,6 +1204,32 @@ def _resolve_omega_pvalue_rounding_mode(g):
12041204 return mode
12051205
12061206
1207+ def _resolve_expectation_method (g ):
1208+ token = 'codon_model'
1209+ if g is not None :
1210+ token = g .get ('expectation_method' , None )
1211+ if (token is None ) or (str (token ).strip () == '' ):
1212+ token = g .get ('omegaC_method' , 'submodel' )
1213+ normalized = str (token ).strip ().lower ()
1214+ if normalized in ['codon_model' , 'submodel' , '' ]:
1215+ return 'codon_model'
1216+ if normalized in ['urn' , 'modelfree' ]:
1217+ return 'urn'
1218+ raise ValueError ('Unsupported expectation_method: {}' .format (token ))
1219+
1220+
1221+ def _resolve_urn_model (g ):
1222+ token = 'wallenius'
1223+ if g is not None :
1224+ token = g .get ('urn_model' , 'wallenius' )
1225+ normalized = str (token ).strip ().lower ()
1226+ if normalized in ['wallenius' , 'fisher' ]:
1227+ return normalized
1228+ if normalized in ['factorized_approx' , 'factorized' , 'legacy_factorized' , 'approx' ]:
1229+ return 'factorized_approx'
1230+ raise ValueError ('urn_model should be one of wallenius, fisher, factorized_approx.' )
1231+
1232+
12071233def _prepare_permutation_branch_sizes (sub_branches , niter , g ):
12081234 sub_branches = np .asarray (sub_branches )
12091235 if sub_branches .ndim != 1 :
@@ -1283,7 +1309,71 @@ def _calc_wallenius_inclusion_probabilities(site_weights, draw_size, float_type=
12831309 return out .astype (float_type , copy = False )
12841310
12851311
1286- def _calc_wallenius_expected_overlap (cb_ids , sub_sites , sub_branches , g , float_type ):
1312+ def _calc_fisher_inclusion_probabilities (site_weights , draw_size , float_type = np .float64 ):
1313+ site_weights = np .asarray (site_weights , dtype = np .float64 ).reshape (- 1 )
1314+ if site_weights .ndim != 1 :
1315+ raise ValueError ('site_weights should be a 1D array.' )
1316+ if (site_weights < 0 ).any ():
1317+ raise ValueError ('site_weights should be non-negative.' )
1318+ draw_size = int (draw_size )
1319+ if draw_size < 0 :
1320+ raise ValueError ('draw_size should be >= 0.' )
1321+ out = np .zeros (shape = site_weights .shape , dtype = np .float64 )
1322+ positive_mask = (site_weights > 0 )
1323+ positive_weights = site_weights [positive_mask ]
1324+ num_positive = int (positive_weights .shape [0 ])
1325+ if (draw_size == 0 ) or (num_positive == 0 ):
1326+ return out .astype (float_type , copy = False )
1327+ if draw_size >= num_positive :
1328+ out [positive_mask ] = 1.0
1329+ return out .astype (float_type , copy = False )
1330+
1331+ target = float (draw_size )
1332+ lo = 0.0
1333+ hi = 1.0
1334+ for _ in range (128 ):
1335+ scaled = hi * positive_weights
1336+ current = (1.0 - (1.0 / (1.0 + scaled ))).sum (dtype = np .float64 )
1337+ if current >= target :
1338+ break
1339+ hi *= 2.0
1340+ for _ in range (96 ):
1341+ mid = (lo + hi ) / 2.0
1342+ scaled = mid * positive_weights
1343+ current = (1.0 - (1.0 / (1.0 + scaled ))).sum (dtype = np .float64 )
1344+ if current < target :
1345+ lo = mid
1346+ else :
1347+ hi = mid
1348+ lam = (lo + hi ) / 2.0
1349+ scaled = lam * positive_weights
1350+ out_positive = 1.0 - (1.0 / (1.0 + scaled ))
1351+ out [positive_mask ] = np .clip (out_positive , a_min = 0.0 , a_max = 1.0 )
1352+ return out .astype (float_type , copy = False )
1353+
1354+
1355+ def _normalize_sub_sites_for_urn_overlap (sub_sites , num_branch ):
1356+ sub_sites = np .asarray (sub_sites , dtype = np .float64 )
1357+ if sub_sites .ndim == 1 :
1358+ sub_sites = np .broadcast_to (sub_sites .reshape (1 , - 1 ), (int (num_branch ), sub_sites .shape [0 ]))
1359+ elif sub_sites .ndim != 2 :
1360+ raise ValueError ('sub_sites should be a 1D or 2D array.' )
1361+ if sub_sites .shape [0 ] != int (num_branch ):
1362+ txt = 'sub_sites branch axis ({}) and sub_branches length ({}) should match.'
1363+ raise ValueError (txt .format (sub_sites .shape [0 ], int (num_branch )))
1364+ if (sub_sites < 0 ).any ():
1365+ raise ValueError ('sub_sites should be non-negative.' )
1366+ return sub_sites
1367+
1368+
1369+ def _calc_weighted_urn_expected_overlap (
1370+ cb_ids ,
1371+ sub_sites ,
1372+ sub_branches ,
1373+ g ,
1374+ float_type ,
1375+ inclusion_probability_func ,
1376+ ):
12871377 cb_ids = np .asarray (cb_ids , dtype = np .int64 )
12881378 if cb_ids .ndim != 2 :
12891379 raise ValueError ('cb_ids should be a 2D array.' )
@@ -1298,17 +1388,7 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
12981388 if not np .isfinite (sub_branches ).all ():
12991389 raise ValueError ('sub_branches should be finite.' )
13001390 np .clip (sub_branches , a_min = 0.0 , a_max = None , out = sub_branches )
1301-
1302- sub_sites = np .asarray (sub_sites , dtype = np .float64 )
1303- if sub_sites .ndim == 1 :
1304- sub_sites = np .broadcast_to (sub_sites .reshape (1 , - 1 ), (sub_branches .shape [0 ], sub_sites .shape [0 ]))
1305- elif sub_sites .ndim != 2 :
1306- raise ValueError ('sub_sites should be a 1D or 2D array.' )
1307- if sub_sites .shape [0 ] != sub_branches .shape [0 ]:
1308- txt = 'sub_sites branch axis ({}) and sub_branches length ({}) should match.'
1309- raise ValueError (txt .format (sub_sites .shape [0 ], sub_branches .shape [0 ]))
1310- if (sub_sites < 0 ).any ():
1311- raise ValueError ('sub_sites should be non-negative.' )
1391+ sub_sites = _normalize_sub_sites_for_urn_overlap (sub_sites = sub_sites , num_branch = sub_branches .shape [0 ])
13121392 if sub_sites .shape [0 ] <= cb_ids .max ():
13131393 raise ValueError ('cb_ids contain out-of-range branch IDs.' )
13141394
@@ -1324,15 +1404,15 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13241404 continue
13251405 size_lo = int (np .clip (base [branch_id ], a_min = 0 , a_max = num_positive ))
13261406 size_hi = int (np .clip (base [branch_id ] + 1 , a_min = 0 , a_max = num_positive ))
1327- prob_lo = _calc_wallenius_inclusion_probabilities (
1407+ prob_lo = inclusion_probability_func (
13281408 site_weights = branch_weights ,
13291409 draw_size = size_lo ,
13301410 float_type = np .float64 ,
13311411 )
13321412 if (frac [branch_id ] <= 0 ) or (size_lo == size_hi ):
13331413 inclusion_prob [branch_id , :] = prob_lo
13341414 else :
1335- prob_hi = _calc_wallenius_inclusion_probabilities (
1415+ prob_hi = inclusion_probability_func (
13361416 site_weights = branch_weights ,
13371417 draw_size = size_hi ,
13381418 float_type = np .float64 ,
@@ -1354,7 +1434,7 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13541434 if num_positive == 0 :
13551435 continue
13561436 draw_size = int (np .clip (rounded_sizes [branch_id ], a_min = 0 , a_max = num_positive ))
1357- inclusion_prob [branch_id , :] = _calc_wallenius_inclusion_probabilities (
1437+ inclusion_prob [branch_id , :] = inclusion_probability_func (
13581438 site_weights = branch_weights ,
13591439 draw_size = draw_size ,
13601440 float_type = np .float64 ,
@@ -1366,6 +1446,63 @@ def _calc_wallenius_expected_overlap(cb_ids, sub_sites, sub_branches, g, float_t
13661446 )
13671447
13681448
1449+ def _calc_wallenius_expected_overlap (cb_ids , sub_sites , sub_branches , g , float_type ):
1450+ return _calc_weighted_urn_expected_overlap (
1451+ cb_ids = cb_ids ,
1452+ sub_sites = sub_sites ,
1453+ sub_branches = sub_branches ,
1454+ g = g ,
1455+ float_type = float_type ,
1456+ inclusion_probability_func = _calc_wallenius_inclusion_probabilities ,
1457+ )
1458+
1459+
1460+ def _calc_fisher_expected_overlap (cb_ids , sub_sites , sub_branches , g , float_type ):
1461+ return _calc_weighted_urn_expected_overlap (
1462+ cb_ids = cb_ids ,
1463+ sub_sites = sub_sites ,
1464+ sub_branches = sub_branches ,
1465+ g = g ,
1466+ float_type = float_type ,
1467+ inclusion_probability_func = _calc_fisher_inclusion_probabilities ,
1468+ )
1469+
1470+
1471+ def _calc_urn_expected_overlap (cb_ids , sub_sites , sub_branches , g , float_type ):
1472+ urn_model = _resolve_urn_model (g = g )
1473+ if urn_model == 'wallenius' :
1474+ return _calc_wallenius_expected_overlap (
1475+ cb_ids = cb_ids ,
1476+ sub_sites = sub_sites ,
1477+ sub_branches = sub_branches ,
1478+ g = g ,
1479+ float_type = float_type ,
1480+ )
1481+ if urn_model == 'fisher' :
1482+ return _calc_fisher_expected_overlap (
1483+ cb_ids = cb_ids ,
1484+ sub_sites = sub_sites ,
1485+ sub_branches = sub_branches ,
1486+ g = g ,
1487+ float_type = float_type ,
1488+ )
1489+ if urn_model == 'factorized_approx' :
1490+ sub_branches = np .asarray (sub_branches , dtype = np .float64 ).reshape (- 1 )
1491+ if sub_branches .ndim != 1 :
1492+ raise ValueError ('sub_branches should be a 1D array.' )
1493+ if not np .isfinite (sub_branches ).all ():
1494+ raise ValueError ('sub_branches should be finite.' )
1495+ np .clip (sub_branches , a_min = 0.0 , a_max = None , out = sub_branches )
1496+ sub_sites = _normalize_sub_sites_for_urn_overlap (sub_sites = sub_sites , num_branch = sub_branches .shape [0 ])
1497+ return _calc_tmp_E_sum (
1498+ cb_ids = cb_ids ,
1499+ sub_sites = sub_sites ,
1500+ sub_branches = sub_branches ,
1501+ float_type = float_type ,
1502+ )
1503+ raise ValueError ('Unsupported urn_model: {}' .format (urn_model ))
1504+
1505+
13691506def _fill_packed_masks_for_sizes (packed_masks_branch , site_p , size_values ):
13701507 size_values = np .asarray (size_values , dtype = np .int64 ).reshape (- 1 )
13711508 if packed_masks_branch .shape [0 ] != size_values .shape [0 ]:
@@ -1742,7 +1879,7 @@ def calc_E_mean(mode, cb_ids, sub_sg, sub_bg, obs_col, list_igad, g, static_sub_
17421879 else :
17431880 sub_sites = static_sub_sites
17441881 sub_branches = substitution .get_sub_branches (sub_bg , mode , sg , a , d )
1745- E_b += _calc_wallenius_expected_overlap (
1882+ E_b += _calc_urn_expected_overlap (
17461883 cb_ids = cb_ids ,
17471884 sub_sites = sub_sites ,
17481885 sub_branches = sub_branches ,
@@ -1777,7 +1914,7 @@ def joblib_calc_E_mean(
17771914 else :
17781915 sub_sites = static_sub_sites
17791916 sub_branches = substitution .get_sub_branches (sub_bg , mode , sg , a , d )
1780- dfEb += _calc_wallenius_expected_overlap (
1917+ dfEb += _calc_urn_expected_overlap (
17811918 cb_ids = cb_ids ,
17821919 sub_sites = sub_sites ,
17831920 sub_branches = sub_branches ,
@@ -2015,7 +2152,8 @@ def subroot_E2nan(cb, tree):
20152152def get_E (cb , g , ON_tensor , OS_tensor ):
20162153 requested_output_stats = _resolve_requested_output_stats (g )
20172154 base_stats = output_stat .get_required_base_stats (requested_output_stats )
2018- if (g ['omegaC_method' ]== 'modelfree' ):
2155+ expectation_method = _resolve_expectation_method (g = g )
2156+ if expectation_method == 'urn' :
20192157 ON_gad , ON_ga , ON_gd = substitution .get_group_state_totals (ON_tensor )
20202158 OS_gad , OS_ga , OS_gd = substitution .get_group_state_totals (OS_tensor )
20212159 g ['N_ind_nomissing_gad' ] = np .where (ON_gad != 0 )
@@ -2027,7 +2165,7 @@ def get_E(cb, g, ON_tensor, OS_tensor):
20272165 for st in base_stats :
20282166 cb ['ECN' + st ] = calc_E_stat (cb , ON_tensor , mode = st , stat = 'mean' , SN = 'N' , g = g )
20292167 cb ['ECS' + st ] = calc_E_stat (cb , OS_tensor , mode = st , stat = 'mean' , SN = 'S' , g = g )
2030- if ( g [ 'omegaC_method' ] == 'submodel' ) :
2168+ elif expectation_method == 'codon_model' :
20312169 id_cols = cb .columns [cb .columns .str .startswith ('branch_id_' )]
20322170 state_nsyE = get_exp_state (g = g , mode = 'nsy' )
20332171 if (g ['current_arity' ]== 2 ):
@@ -2059,6 +2197,8 @@ def get_E(cb, g, ON_tensor, OS_tensor):
20592197 )
20602198 cb = table .merge_tables (cb , cbES )
20612199 del state_cdnE ,cbES
2200+ else :
2201+ raise ValueError ('Unsupported expectation_method: {}' .format (expectation_method ))
20622202 cb = substitution .add_dif_stats (cb , g ['float_tol' ], prefix = 'EC' , output_stats = requested_output_stats )
20632203 cb = subroot_E2nan (cb , tree = g ['tree' ])
20642204 return cb
@@ -2598,7 +2738,7 @@ def _calc_poisson_count_matrix(
25982738 else :
25992739 sub_sites = static_sub_sites
26002740 sub_branches = substitution .get_sub_branches (sub_bg , mode , sg , a , d )
2601- mean_count = _calc_wallenius_expected_overlap (
2741+ mean_count = _calc_urn_expected_overlap (
26022742 cb_ids = cb_ids ,
26032743 sub_sites = sub_sites ,
26042744 sub_branches = sub_branches ,
@@ -2717,7 +2857,7 @@ def _calc_nbinom_count_matrix(
27172857 else :
27182858 sub_sites = static_sub_sites
27192859 sub_branches = substitution .get_sub_branches (sub_bg , mode , sg , a , d )
2720- mean_count = _calc_wallenius_expected_overlap (
2860+ mean_count = _calc_urn_expected_overlap (
27212861 cb_ids = cb_ids ,
27222862 sub_sites = sub_sites ,
27232863 sub_branches = sub_branches ,
@@ -2817,7 +2957,7 @@ def _calc_poisson_full_count_matrix(
28172957 sub_site_mass [nonzero_branch , :] /
28182958 sub_branches [nonzero_branch , None ]
28192959 )
2820- mean_count = _calc_wallenius_expected_overlap (
2960+ mean_count = _calc_urn_expected_overlap (
28212961 cb_ids = cb_ids ,
28222962 sub_sites = sub_sites ,
28232963 sub_branches = sub_branches ,
@@ -3141,8 +3281,9 @@ def _resolve_omega_pvalue_dsc_calibration_transformation(cb, sub, g):
31413281def add_omega_empirical_pvalues (cb , ON_tensor , OS_tensor , g ):
31423282 if not bool (g .get ('calc_omega_pvalue' , False )):
31433283 return cb
3144- if str (g .get ('omegaC_method' , '' )).strip ().lower () != 'modelfree' :
3145- sys .stderr .write ('Skipping --calc_omega_pvalue because --omegaC_method is not "modelfree".\n ' )
3284+ if _resolve_expectation_method (g = g ) != 'urn' :
3285+ txt = 'Skipping --calc_omega_pvalue because --expectation_method is not "urn".\n '
3286+ sys .stderr .write (txt )
31463287 return cb
31473288 null_model = _resolve_omega_pvalue_null_model (g = g )
31483289 txt = 'omega_C empirical p-value null model: {}'
0 commit comments