Skip to content

Commit 65ed4d6

Browse files
committed
Implement generalized skyline and arbitrary clock models
1 parent e5731a3 commit 65ed4d6

40 files changed

+521
-71
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
strategy:
1717
fail-fast: false
1818
matrix:
19-
python-version: [3.7, 3.8, 3.9, '3.10', '3.11']
19+
python-version: [3.8, 3.9, '3.10', '3.11']
2020
env:
2121
# https://github.com/pytest-dev/pytest/issues/2042
2222
PY_IGNORE_IMPORTMISMATCH: 1

environment-dev.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ channels:
44
- bioconda
55
- conda-forge
66
dependencies:
7-
- python>=3.7
7+
- python>=3.10
88
- numpy>=1.7
99
- dendropy
1010
- pytorch>=1.9
1111
- pip
1212
# Dev dependencies
13-
- black
13+
- black=23.3.0
1414
- docformatter
1515
- flake8
1616
- isort

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- pytorch
44
- bioconda
55
dependencies:
6-
- python=3.7
6+
- python=3.10
77
- numpy>=1.7
88
- dendropy
99
- pytorch>=1.9

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ classifiers =
1313
Intended Audience :: Science/Research
1414
License :: OSI Approved :: GNU General Public License v3 (GPLv3)
1515
Operating System :: OS Independent
16-
Programming Language :: Python :: 3.7
1716
Programming Language :: Python :: 3.8
1817
Programming Language :: Python :: 3.9
1918
Programming Language :: Python :: 3.10
19+
Programming Language :: Python :: 3.11
2020
Topic :: Scientific/Engineering :: Bio-Informatics
2121

2222
[options]
23-
python_requires = >=3.7
23+
python_requires = >=3.8
2424
packages = find:
2525
package_dir =
2626
=.

torchtree/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This is the root package of the torchtree framework."""
2+
23
from ._version import __version__
34
from .core.parameter import CatParameter, Parameter, TransformedParameter, ViewParameter
45

torchtree/cli/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The cli package contains modules for creating JSON configuration files
22
through a command-line interface."""
3+
34
from torchtree.cli.plugin_manager import PluginManager
45

56
PLUGIN_MANAGER = PluginManager()

torchtree/cli/advi.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -840,14 +840,19 @@ def create_logger(id_, parameters, arg):
840840
if arg.coalescent:
841841
models.append('coalescent')
842842
if arg.coalescent in COALESCENT_PIECEWISE:
843-
models.append('gmrf')
844843
models.append(
845844
{
846845
'id': arg.coalescent,
847846
'type': 'JointDistributionModel',
848-
'distributions': ['coalescent', 'gmrf'],
847+
'distributions': ['coalescent'],
849848
}
850849
)
850+
if arg.theta_prior == 'eml':
851+
models[-1]['distributions'].append('coalescent.theta.eml')
852+
models.append('coalescent.theta.eml')
853+
else:
854+
models[-1]['distributions'].append('gmrf')
855+
models.append('gmrf')
851856

852857
return {
853858
"id": id_,
@@ -877,14 +882,19 @@ def create_sampler(id_, var_id, parameters, arg):
877882
if arg.coalescent:
878883
models.append('coalescent')
879884
if arg.coalescent in COALESCENT_PIECEWISE:
880-
models.append('gmrf')
881885
models.append(
882886
{
883887
'id': arg.coalescent,
884888
'type': 'JointDistributionModel',
885-
'distributions': ['coalescent', 'gmrf'],
889+
'distributions': ['coalescent'],
886890
}
887891
)
892+
if arg.theta_prior == 'eml':
893+
models[-1]['distributions'].append('coalescent.theta.eml')
894+
models.append('coalescent.theta.eml')
895+
else:
896+
models[-1]['distributions'].append('gmrf')
897+
models.append('gmrf')
888898

889899
return {
890900
"id": id_,
@@ -937,7 +947,7 @@ def build_advi(arg):
937947
if arg.clock is not None and arg.heights == 'ratio':
938948
jacobians_list.append('tree')
939949

940-
if arg.coalescent in COALESCENT_PIECEWISE:
950+
if arg.coalescent in COALESCENT_PIECEWISE and arg.theta_prior != 'eml':
941951
jacobians_list.remove("coalescent.theta")
942952

943953
joint_jacobian = {
@@ -984,13 +994,20 @@ def build_advi(arg):
984994
parameters.extend(
985995
(
986996
f'{branch_model_id}.rates.prior.mean',
987-
f'{branch_model_id}.rates.prior.scale',
997+
f'{branch_model_id}.rates.prior.stdev',
998+
)
999+
)
1000+
elif arg.clock == 'ncln':
1001+
parameters.extend(
1002+
(
1003+
f'{branch_model_id}.location',
1004+
f'{branch_model_id}.scale',
9881005
)
9891006
)
9901007
else:
9911008
parameters.append(f"{branch_model_id}.rate")
9921009

993-
if arg.clock == 'horseshoe' or arg.clock == 'ucln':
1010+
if arg.clock == 'horseshoe' or arg.clock in ('ucln', 'ncln'):
9941011
parameters.append(f'{branch_model_id}.rates')
9951012
else:
9961013
parameters = ['tree.blens']
@@ -999,7 +1016,11 @@ def build_advi(arg):
9991016
if arg.coalescent_integrated is None:
10001017
parameters.append("coalescent.theta")
10011018

1002-
if arg.coalescent in COALESCENT_PIECEWISE and not arg.gmrf_integrated:
1019+
if (
1020+
arg.coalescent in COALESCENT_PIECEWISE
1021+
and not arg.gmrf_integrated
1022+
and arg.theta_prior is None
1023+
):
10031024
parameters.append('gmrf.precision')
10041025
elif arg.coalescent == 'exponential':
10051026
parameters.append('coalescent.growth')

torchtree/cli/argparse_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,19 @@ def str_or_float(arg, choices):
3131
)
3232

3333

34+
def str_or_int(arg):
35+
"""Used by argparse when the argument can be either an integer or a string."""
36+
try:
37+
return int(arg)
38+
except ValueError:
39+
if isinstance(arg, str):
40+
return arg
41+
else:
42+
raise argparse.ArgumentTypeError(
43+
'invalid choice (choose from an integer or a string)'
44+
)
45+
46+
3447
def list_of_float(arg, length):
3548
"""Used by argparse when the argument should be a list of floats."""
3649
values = arg.split(",")

0 commit comments

Comments
 (0)