|
1 | 1 | # The __init__.py file is loaded when the package is loaded. |
2 | 2 | # It is used to indicate that the directory in which it resides is a Python package |
3 | | -import warnings |
4 | | -from importlib import metadata |
5 | 3 |
|
6 | | -import numpy as np |
| 4 | +from importlib import metadata |
7 | 5 |
|
8 | 6 | __version__ = metadata.version("growthcurves") |
9 | 7 |
|
10 | 8 | from . import inference, models, non_parametric, parametric, plot, preprocessing |
| 9 | +from .fit import fit_model |
11 | 10 | from .inference import compare_methods |
12 | 11 | from .models import ( |
13 | 12 | MODEL_REGISTRY, |
|
33 | 32 | "blank_subtraction", |
34 | 33 | "path_correct", |
35 | 34 | "compare_methods", |
| 35 | + "fit_model", |
36 | 36 | ] |
37 | | - |
38 | | - |
39 | | -def fit_model( |
40 | | - t: np.ndarray, |
41 | | - N: np.ndarray, |
42 | | - model_name: str, |
43 | | - lag_threshold: float = 0.15, |
44 | | - exp_threshold: float = 0.15, |
45 | | - phase_boundary_method=None, |
46 | | - **kwargs, |
47 | | -) -> tuple[dict, dict]: |
48 | | - if model_name in models.MODEL_REGISTRY["non_parametric"]: |
49 | | - fit_res = non_parametric.fit_non_parametric(t, N, method=model_name, **kwargs) |
50 | | - else: |
51 | | - fit_res = parametric.fit_parametric(t, N, method=model_name, **kwargs) |
52 | | - # return None if fit fails, along with bad fit stats |
53 | | - if fit_res is None: |
54 | | - warnings.warn( |
55 | | - f"Model fitting failed for model {model_name}. Returning None.", |
56 | | - stacklevel=2, |
57 | | - ) |
58 | | - return None, inference.bad_fit_stats() |
59 | | - |
60 | | - stats_res = inference.extract_stats( |
61 | | - fit_res, |
62 | | - t=t, |
63 | | - N=N, |
64 | | - lag_threshold=lag_threshold, |
65 | | - exp_threshold=exp_threshold, |
66 | | - phase_boundary_method=phase_boundary_method, |
67 | | - **kwargs, |
68 | | - ) |
69 | | - |
70 | | - stats_res["model_name"] = model_name |
71 | | - return fit_res, stats_res |
72 | | - |
73 | | - |
74 | | -# ! not so good for dynamic inspection tools... |
75 | | -fit_model.__doc__ = f"""Fit a growth model to the provided t and N. |
76 | | -
|
77 | | - Parameters |
78 | | - ---------- |
79 | | - t : np.ndarray |
80 | | - Time points corresponding to N (in hours). |
81 | | - N : np.ndarray |
82 | | - Growth data points corresponding to t. |
83 | | - model_name : str |
84 | | - One of the models in {', '.join(get_all_models())}. |
85 | | - lag_threshold : float, optional |
86 | | - Fraction of μ_max to define end of lag phase (threshold method, default: 0.15). |
87 | | - exp_threshold : float, optional |
88 | | - Fraction of μ_max to define end of exponential phase (threshold method, |
89 | | - default: 0.15). |
90 | | - phase_boundary_method : str, optional |
91 | | - Method to determine phase boundaries ("tangent", "threshold", |
92 | | - or None for default for model class). |
93 | | - **kwargs |
94 | | - Additiona keyword arguments to be passed to fitting and inference functions. |
95 | | -
|
96 | | - Returns |
97 | | - ------- |
98 | | - tuple[dict, dict] |
99 | | - Return tuple of two dictionaries: (fit_res, stats_res) |
100 | | - - fit_res: Dictionary containing fitted model parameters. |
101 | | - - stats_res: Dictionary containing goodness-of-fit statistics and growth metrics |
102 | | -
|
103 | | - """ |
0 commit comments