Skip to content

Commit dbbb476

Browse files
authored
Merge pull request #52 from biosustain/fix_entry_points_documentation
📝🐛 functions in __init__ are not documented in api documentation
2 parents e0cd284 + d611532 commit dbbb476

File tree

2 files changed

+82
-70
lines changed

2 files changed

+82
-70
lines changed

src/growthcurves/__init__.py

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# The __init__.py file is loaded when the package is loaded.
22
# It is used to indicate that the directory in which it resides is a Python package
3-
import warnings
4-
from importlib import metadata
53

6-
import numpy as np
4+
from importlib import metadata
75

86
__version__ = metadata.version("growthcurves")
97

108
from . import inference, models, non_parametric, parametric, plot, preprocessing
9+
from .fit import fit_model
1110
from .inference import compare_methods
1211
from .models import (
1312
MODEL_REGISTRY,
@@ -33,71 +32,5 @@
3332
"blank_subtraction",
3433
"path_correct",
3534
"compare_methods",
35+
"fit_model",
3636
]
37-
38-
39-
def fit_model(
40-
t: np.ndarray,
41-
N: np.ndarray,
42-
model_name: str,
43-
lag_threshold: float = 0.15,
44-
exp_threshold: float = 0.15,
45-
phase_boundary_method=None,
46-
**kwargs,
47-
) -> tuple[dict, dict]:
48-
if model_name in models.MODEL_REGISTRY["non_parametric"]:
49-
fit_res = non_parametric.fit_non_parametric(t, N, method=model_name, **kwargs)
50-
else:
51-
fit_res = parametric.fit_parametric(t, N, method=model_name, **kwargs)
52-
# return None if fit fails, along with bad fit stats
53-
if fit_res is None:
54-
warnings.warn(
55-
f"Model fitting failed for model {model_name}. Returning None.",
56-
stacklevel=2,
57-
)
58-
return None, inference.bad_fit_stats()
59-
60-
stats_res = inference.extract_stats(
61-
fit_res,
62-
t=t,
63-
N=N,
64-
lag_threshold=lag_threshold,
65-
exp_threshold=exp_threshold,
66-
phase_boundary_method=phase_boundary_method,
67-
**kwargs,
68-
)
69-
70-
stats_res["model_name"] = model_name
71-
return fit_res, stats_res
72-
73-
74-
# ! not so good for dynamic inspection tools...
75-
fit_model.__doc__ = f"""Fit a growth model to the provided t and N.
76-
77-
Parameters
78-
----------
79-
t : np.ndarray
80-
Time points corresponding to N (in hours).
81-
N : np.ndarray
82-
Growth data points corresponding to t.
83-
model_name : str
84-
One of the models in {', '.join(get_all_models())}.
85-
lag_threshold : float, optional
86-
Fraction of μ_max to define end of lag phase (threshold method, default: 0.15).
87-
exp_threshold : float, optional
88-
Fraction of μ_max to define end of exponential phase (threshold method,
89-
default: 0.15).
90-
phase_boundary_method : str, optional
91-
Method to determine phase boundaries ("tangent", "threshold",
92-
or None for default for model class).
93-
**kwargs
94-
Additiona keyword arguments to be passed to fitting and inference functions.
95-
96-
Returns
97-
-------
98-
tuple[dict, dict]
99-
Return tuple of two dictionaries: (fit_res, stats_res)
100-
- fit_res: Dictionary containing fitted model parameters.
101-
- stats_res: Dictionary containing goodness-of-fit statistics and growth metrics
102-
103-
"""

src/growthcurves/fit.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Entry point of growthcurves package"""
2+
3+
import warnings
4+
5+
import numpy as np
6+
7+
from . import inference, models, non_parametric, parametric
8+
from .models import get_all_models
9+
10+
__all__ = [
11+
"fit_model",
12+
]
13+
14+
15+
def fit_model(
16+
t: np.ndarray,
17+
N: np.ndarray,
18+
model_name: str,
19+
lag_threshold: float = 0.15,
20+
exp_threshold: float = 0.15,
21+
phase_boundary_method=None,
22+
**kwargs,
23+
) -> tuple[dict, dict]:
24+
if model_name in models.MODEL_REGISTRY["non_parametric"]:
25+
fit_res = non_parametric.fit_non_parametric(t, N, method=model_name, **kwargs)
26+
else:
27+
fit_res = parametric.fit_parametric(t, N, method=model_name, **kwargs)
28+
# return None if fit fails, along with bad fit stats
29+
if fit_res is None:
30+
warnings.warn(
31+
f"Model fitting failed for model {model_name}. Returning None.",
32+
stacklevel=2,
33+
)
34+
return None, inference.bad_fit_stats()
35+
36+
stats_res = inference.extract_stats(
37+
fit_res,
38+
t=t,
39+
N=N,
40+
lag_threshold=lag_threshold,
41+
exp_threshold=exp_threshold,
42+
phase_boundary_method=phase_boundary_method,
43+
**kwargs,
44+
)
45+
46+
stats_res["model_name"] = model_name
47+
return fit_res, stats_res
48+
49+
50+
# ! not so good for dynamic inspection tools...
51+
fit_model.__doc__ = f"""Fit a growth model to the provided t and N.
52+
53+
Parameters
54+
----------
55+
t : np.ndarray
56+
Time points corresponding to N (in hours).
57+
N : np.ndarray
58+
Growth data points corresponding to t.
59+
model_name : str
60+
One of the models in {', '.join(get_all_models())}.
61+
lag_threshold : float, optional
62+
Fraction of μ_max to define end of lag phase (threshold method, default: 0.15).
63+
exp_threshold : float, optional
64+
Fraction of μ_max to define end of exponential phase (threshold method,
65+
default: 0.15).
66+
phase_boundary_method : str, optional
67+
Method to determine phase boundaries ("tangent", "threshold",
68+
or None for default for model class).
69+
**kwargs
70+
Additiona keyword arguments to be passed to fitting and inference functions.
71+
72+
Returns
73+
-------
74+
tuple[dict, dict]
75+
Return tuple of two dictionaries: (fit_res, stats_res)
76+
- fit_res: Dictionary containing fitted model parameters.
77+
- stats_res: Dictionary containing goodness-of-fit statistics and growth metrics
78+
79+
"""

0 commit comments

Comments
 (0)