Skip to content

Commit ae4ac47

Browse files
committed
🎨 functions in __init__ are not documented in api documentation
- put fit_model in it's own module - import it at the top-level for compatibility
1 parent e0cd284 commit ae4ac47

File tree

2 files changed

+80
-70
lines changed

2 files changed

+80
-70
lines changed

src/growthcurves/__init__.py

Lines changed: 3 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# The __init__.py file is loaded when the package is loaded.
22
# It is used to indicate that the directory in which it resides is a Python package
3-
import warnings
4-
from importlib import metadata
53

6-
import numpy as np
4+
from importlib import metadata
75

86
__version__ = metadata.version("growthcurves")
97

108
from . import inference, models, non_parametric, parametric, plot, preprocessing
9+
from .fit import fit_model
1110
from .inference import compare_methods
1211
from .models import (
1312
MODEL_REGISTRY,
@@ -33,71 +32,5 @@
3332
"blank_subtraction",
3433
"path_correct",
3534
"compare_methods",
35+
"fit_model",
3636
]
37-
38-
39-
def fit_model(
40-
t: np.ndarray,
41-
N: np.ndarray,
42-
model_name: str,
43-
lag_threshold: float = 0.15,
44-
exp_threshold: float = 0.15,
45-
phase_boundary_method=None,
46-
**kwargs,
47-
) -> tuple[dict, dict]:
48-
if model_name in models.MODEL_REGISTRY["non_parametric"]:
49-
fit_res = non_parametric.fit_non_parametric(t, N, method=model_name, **kwargs)
50-
else:
51-
fit_res = parametric.fit_parametric(t, N, method=model_name, **kwargs)
52-
# return None if fit fails, along with bad fit stats
53-
if fit_res is None:
54-
warnings.warn(
55-
f"Model fitting failed for model {model_name}. Returning None.",
56-
stacklevel=2,
57-
)
58-
return None, inference.bad_fit_stats()
59-
60-
stats_res = inference.extract_stats(
61-
fit_res,
62-
t=t,
63-
N=N,
64-
lag_threshold=lag_threshold,
65-
exp_threshold=exp_threshold,
66-
phase_boundary_method=phase_boundary_method,
67-
**kwargs,
68-
)
69-
70-
stats_res["model_name"] = model_name
71-
return fit_res, stats_res
72-
73-
74-
# ! not so good for dynamic inspection tools...
75-
fit_model.__doc__ = f"""Fit a growth model to the provided t and N.
76-
77-
Parameters
78-
----------
79-
t : np.ndarray
80-
Time points corresponding to N (in hours).
81-
N : np.ndarray
82-
Growth data points corresponding to t.
83-
model_name : str
84-
One of the models in {', '.join(get_all_models())}.
85-
lag_threshold : float, optional
86-
Fraction of μ_max to define end of lag phase (threshold method, default: 0.15).
87-
exp_threshold : float, optional
88-
Fraction of μ_max to define end of exponential phase (threshold method,
89-
default: 0.15).
90-
phase_boundary_method : str, optional
91-
Method to determine phase boundaries ("tangent", "threshold",
92-
or None for default for model class).
93-
**kwargs
94-
Additiona keyword arguments to be passed to fitting and inference functions.
95-
96-
Returns
97-
-------
98-
tuple[dict, dict]
99-
Return tuple of two dictionaries: (fit_res, stats_res)
100-
- fit_res: Dictionary containing fitted model parameters.
101-
- stats_res: Dictionary containing goodness-of-fit statistics and growth metrics
102-
103-
"""

src/growthcurves/fit.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Entry point of growthcurves package"""
2+
3+
import warnings
4+
5+
import numpy as np
6+
7+
from . import inference, models, non_parametric, parametric
8+
from .models import get_all_models
9+
10+
__all__ = ["fit_model",]
11+
12+
13+
def fit_model(
14+
t: np.ndarray,
15+
N: np.ndarray,
16+
model_name: str,
17+
lag_threshold: float = 0.15,
18+
exp_threshold: float = 0.15,
19+
phase_boundary_method=None,
20+
**kwargs,
21+
) -> tuple[dict, dict]:
22+
if model_name in models.MODEL_REGISTRY["non_parametric"]:
23+
fit_res = non_parametric.fit_non_parametric(t, N, method=model_name, **kwargs)
24+
else:
25+
fit_res = parametric.fit_parametric(t, N, method=model_name, **kwargs)
26+
# return None if fit fails, along with bad fit stats
27+
if fit_res is None:
28+
warnings.warn(
29+
f"Model fitting failed for model {model_name}. Returning None.",
30+
stacklevel=2,
31+
)
32+
return None, inference.bad_fit_stats()
33+
34+
stats_res = inference.extract_stats(
35+
fit_res,
36+
t=t,
37+
N=N,
38+
lag_threshold=lag_threshold,
39+
exp_threshold=exp_threshold,
40+
phase_boundary_method=phase_boundary_method,
41+
**kwargs,
42+
)
43+
44+
stats_res["model_name"] = model_name
45+
return fit_res, stats_res
46+
47+
48+
# ! not so good for dynamic inspection tools...
49+
fit_model.__doc__ = f"""Fit a growth model to the provided t and N.
50+
51+
Parameters
52+
----------
53+
t : np.ndarray
54+
Time points corresponding to N (in hours).
55+
N : np.ndarray
56+
Growth data points corresponding to t.
57+
model_name : str
58+
One of the models in {', '.join(get_all_models())}.
59+
lag_threshold : float, optional
60+
Fraction of μ_max to define end of lag phase (threshold method, default: 0.15).
61+
exp_threshold : float, optional
62+
Fraction of μ_max to define end of exponential phase (threshold method,
63+
default: 0.15).
64+
phase_boundary_method : str, optional
65+
Method to determine phase boundaries ("tangent", "threshold",
66+
or None for default for model class).
67+
**kwargs
68+
Additiona keyword arguments to be passed to fitting and inference functions.
69+
70+
Returns
71+
-------
72+
tuple[dict, dict]
73+
Return tuple of two dictionaries: (fit_res, stats_res)
74+
- fit_res: Dictionary containing fitted model parameters.
75+
- stats_res: Dictionary containing goodness-of-fit statistics and growth metrics
76+
77+
"""

0 commit comments

Comments
 (0)