Skip to content

Commit 1726dac

Browse files
Merge pull request #1 from ottenbreit-data-science/finishing
Modified numpy version requirement. Better cleanup after fit.
2 parents ac2d1c7 + ab22926 commit 1726dac

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

cpp/APLRRegressor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ void APLRRegressor::determine_interactions_to_consider()
422422
if(!(terms[sorted_latest_split_point_errors_indices[i]]==interaction))
423423
{
424424
interaction.given_terms.push_back(terms[sorted_latest_split_point_errors_indices[i]]);
425+
interaction.given_terms[interaction.given_terms.size()-1].cleanup_when_this_term_was_added_as_a_given_predictor();
425426
bool already_exists{false};
426427
for (size_t k = 0; k < terms_eligible_current.size(); ++k)
427428
{
@@ -836,6 +837,10 @@ void APLRRegressor::cleanup_after_fit()
836837
distributed_terms.clear();
837838
interactions_to_consider.clear();
838839
error_index_for_interactions_to_consider.resize(0);
840+
for (size_t i = 0; i < terms.size(); ++i)
841+
{
842+
terms[i].cleanup_after_fit();
843+
}
839844
}
840845

841846
VectorXd APLRRegressor::predict(const MatrixXd &X)

cpp/term.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ class Term
5353
void estimate_split_point_on_discretized_data();
5454
void calculate_coefficient_and_error_on_discretized_data(bool direction_right, double split_point);
5555
void estimate_coefficient_and_error_on_all_data();
56-
void clean_up_after_estimate_split_point();
56+
void cleanup_after_estimate_split_point();
57+
void cleanup_after_fit();
58+
void cleanup_when_this_term_was_added_as_a_given_predictor();
5759

5860
public:
5961
//fields
@@ -171,7 +173,7 @@ void Term::estimate_split_point(const MatrixXd &X,const VectorXd &y,const Vector
171173
discretize_data_by_bin();
172174
estimate_split_point_on_discretized_data();
173175
estimate_coefficient_and_error_on_all_data();
174-
clean_up_after_estimate_split_point();
176+
cleanup_after_estimate_split_point();
175177
}
176178

177179
//Calculate indices that get zeroed out during calculate() because of given terms. Also calculates indices of those observations that do not.
@@ -582,7 +584,7 @@ void Term::estimate_coefficient_and_error_on_all_data()
582584
}
583585
}
584586

585-
void Term::clean_up_after_estimate_split_point()
587+
void Term::cleanup_after_estimate_split_point()
586588
{
587589
given_terms_indices.not_zeroed.resize(0);
588590
given_terms_indices.zeroed.resize(0);
@@ -593,6 +595,22 @@ void Term::clean_up_after_estimate_split_point()
593595
errors_initial.resize(0);
594596
}
595597

598+
void Term::cleanup_after_fit()
599+
{
600+
bins_start_index.clear();
601+
bins_end_index.clear();
602+
bins_split_points_left.clear();
603+
bins_split_points_right.clear();
604+
values_discretized.resize(0);
605+
sample_weight_discretized.resize(0);
606+
}
607+
608+
void Term::cleanup_when_this_term_was_added_as_a_given_predictor()
609+
{
610+
cleanup_after_fit();
611+
coefficient_steps.resize(0);
612+
}
613+
596614
VectorXd Term::calculate_prediction_contribution(const MatrixXd &X)
597615
{
598616
VectorXd values{calculate(X)};

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
setuptools.setup(
1717
name='aplr',
18-
version='1.0.1',
18+
version='1.0.2',
1919
description='Automatic Piecewise Linear Regression',
2020
ext_modules=[sfc_module],
2121
author="Mathias von Ottenbreit",
2222
author_email="[email protected]",
2323
long_description="Build predictive and interpretable parametric machine learning models in Python based on the Automatic Piecewise Linear Regression methodology developed by Mathias von Ottenbreit.",
2424
long_description_content_type="text/markdown",
2525
packages=['aplr'],
26-
install_requires=["numpy"],
26+
install_requires=["numpy>=1.20"],
2727
python_requires='>=3.8',
2828
classifiers=["License :: OSI Approved :: MIT License"],
2929
license="MIT",

0 commit comments

Comments
 (0)