diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index b8cf0a6..cfc1afc 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -25,6 +25,7 @@ plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, + vi_to_kulprit, ) __all__ = [ @@ -41,6 +42,7 @@ "plot_scatter_submodels", "plot_variable_importance", "plot_variable_inclusion", + "vi_to_kulprit", ] __version__ = "0.10.0" diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 78ce920..cf804c5 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -1006,6 +1006,24 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 return vi_results +def vi_to_kulprit(vi_results: dict) -> list[list[str]]: + """ + Export variable importance results to Kulprit format. + + Parameters + ---------- + vi_results : dict + Dictionary computed with `compute_variable_importance` + + Returns + ------- + list[list[str]] + A list of lists containing variable names for each submodel. + """ + clean_labels = [label.strip("+ ") for label in vi_results["labels"]] + return [clean_labels[:idx] for idx in range(len(clean_labels))] + + def plot_variable_importance( vi_results: dict, submodels: Optional[Union[list[int], np.ndarray, tuple[int, ...]]] = None, diff --git a/tests/test_utils.py b/tests/test_utils.py index dbf3aca..ed85af7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,6 +79,10 @@ def test_vi(self, kwargs): pmb.plot_variable_importance(vi_results, **kwargs) pmb.plot_scatter_submodels(vi_results, **kwargs) + user_terms = pmb.vi_to_kulprit(vi_results) + assert len(user_terms) == 3 + assert all("+" not in term for terms in user_terms[1:] for term in terms) + def test_pdp_pandas_labels(self): pd = pytest.importorskip("pandas")