Skip to content

Commit a7cbe04

Browse files
authored
Initialize relabeled_failures in SurvivalSplittingRule ctor (#1513)
1 parent 8b8b640 commit a7cbe04

20 files changed

+25
-14
lines changed

core/src/splitting/SurvivalSplittingRule.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323

2424
namespace grf {
2525

26-
SurvivalSplittingRule::SurvivalSplittingRule(double alpha):
27-
alpha(alpha) {
26+
SurvivalSplittingRule::SurvivalSplittingRule(size_t num_data_rows, double alpha):
27+
relabeled_failures(num_data_rows, 0), alpha(alpha) {
2828
}
2929

3030
bool SurvivalSplittingRule::find_best_split(const Data& data,
@@ -99,9 +99,6 @@ void SurvivalSplittingRule::find_best_split_internal(const Data& data,
9999
std::vector<double> at_risk(num_failures + 1);
100100
at_risk[0] = static_cast<double>(size_node);
101101

102-
// allocating an N-sized (full data set size) array is faster than a hash table
103-
std::vector<size_t> relabeled_failures(data.get_num_rows());
104-
105102
std::vector<double> numerator_weights(num_failures + 1);
106103
std::vector<double> denominator_weights(num_failures + 1);
107104

@@ -135,7 +132,7 @@ void SurvivalSplittingRule::find_best_split_internal(const Data& data,
135132

136133
for (auto& var : possible_split_vars) {
137134
find_best_split_value(data, var, size_node, min_child_size, num_failures_node, num_failures,
138-
best_value, best_var, best_logrank, best_send_missing_left, samples, relabeled_failures,
135+
best_value, best_var, best_logrank, best_send_missing_left, samples,
139136
count_failure, at_risk, numerator_weights, denominator_weights);
140137
}
141138
}
@@ -151,7 +148,6 @@ void SurvivalSplittingRule::find_best_split_value(const Data& data,
151148
double& best_logrank,
152149
bool& best_send_missing_left,
153150
const std::vector<size_t>& samples,
154-
const std::vector<size_t>& relabeled_failures,
155151
const std::vector<double>& count_failure,
156152
const std::vector<double>& at_risk,
157153
const std::vector<double>& numerator_weights,

core/src/splitting/SurvivalSplittingRule.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ namespace grf {
3030

3131
class SurvivalSplittingRule final: public SplittingRule {
3232
public:
33-
SurvivalSplittingRule(double alpha);
33+
SurvivalSplittingRule(size_t num_data_rows, double alpha);
3434

3535
bool find_best_split(const Data& data,
3636
size_t node,
@@ -66,7 +66,6 @@ class SurvivalSplittingRule final: public SplittingRule {
6666
double& best_logrank,
6767
bool& best_send_missing_left,
6868
const std::vector<size_t>& samples,
69-
const std::vector<size_t>& relabeled_failures,
7069
const std::vector<double>& count_failure,
7170
const std::vector<double>& at_risk,
7271
const std::vector<double>& numerator_weights,
@@ -81,6 +80,7 @@ class SurvivalSplittingRule final: public SplittingRule {
8180
const std::vector<double>& numerator_weights,
8281
const std::vector<double>& denominator_weights);
8382

83+
std::vector<size_t> relabeled_failures;
8484
double alpha;
8585

8686
DISALLOW_COPY_AND_ASSIGN(SurvivalSplittingRule);

core/src/splitting/factory/CausalSurvivalSplittingRuleFactory.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
namespace grf {
2424

2525
std::unique_ptr<SplittingRule> CausalSurvivalSplittingRuleFactory::create(size_t max_num_unique_values,
26-
const TreeOptions& options) const {
26+
const Data& data,
27+
const TreeOptions& options) const {
2728
return std::unique_ptr<SplittingRule>(new CausalSurvivalSplittingRule(
2829
max_num_unique_values,
2930
options.get_min_node_size(),

core/src/splitting/factory/CausalSurvivalSplittingRuleFactory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class CausalSurvivalSplittingRuleFactory final: public SplittingRuleFactory {
2929
public:
3030
CausalSurvivalSplittingRuleFactory() = default;
3131
std::unique_ptr<SplittingRule> create(size_t max_num_unique_values,
32+
const Data& data,
3233
const TreeOptions& options) const;
3334
private:
3435
DISALLOW_COPY_AND_ASSIGN(CausalSurvivalSplittingRuleFactory);

core/src/splitting/factory/InstrumentalSplittingRuleFactory.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
namespace grf {
2424

2525
std::unique_ptr<SplittingRule> InstrumentalSplittingRuleFactory::create(size_t max_num_unique_values,
26+
const Data& data,
2627
const TreeOptions& options) const {
2728
return std::unique_ptr<SplittingRule>(new InstrumentalSplittingRule(
2829
max_num_unique_values,

core/src/splitting/factory/InstrumentalSplittingRuleFactory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class InstrumentalSplittingRuleFactory final: public SplittingRuleFactory {
3737
public:
3838
InstrumentalSplittingRuleFactory() = default;
3939
std::unique_ptr<SplittingRule> create(size_t max_num_unique_values,
40+
const Data& data,
4041
const TreeOptions& options) const;
4142
private:
4243
DISALLOW_COPY_AND_ASSIGN(InstrumentalSplittingRuleFactory);

core/src/splitting/factory/MultiCausalSplittingRuleFactory.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ MultiCausalSplittingRuleFactory::MultiCausalSplittingRuleFactory(size_t response
2828
num_treatments(num_treatments) {}
2929

3030
std::unique_ptr<SplittingRule> MultiCausalSplittingRuleFactory::create(size_t max_num_unique_values,
31+
const Data& data,
3132
const TreeOptions& options) const {
3233
return std::unique_ptr<SplittingRule>(new MultiCausalSplittingRule(
3334
max_num_unique_values,

core/src/splitting/factory/MultiCausalSplittingRuleFactory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class MultiCausalSplittingRuleFactory final: public SplittingRuleFactory {
3030
size_t num_treatments);
3131

3232
std::unique_ptr<SplittingRule> create(size_t max_num_unique_values,
33+
const Data& data,
3334
const TreeOptions& options) const;
3435
private:
3536
size_t response_length;

core/src/splitting/factory/MultiRegressionSplittingRuleFactory.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ MultiRegressionSplittingRuleFactory::MultiRegressionSplittingRuleFactory(size_t
2626
num_outcomes(num_outcomes) {}
2727

2828
std::unique_ptr<SplittingRule> MultiRegressionSplittingRuleFactory::create(size_t max_num_unique_values,
29+
const Data& data,
2930
const TreeOptions& options) const {
3031
return std::unique_ptr<SplittingRule>(new MultiRegressionSplittingRule(
3132
max_num_unique_values,

core/src/splitting/factory/MultiRegressionSplittingRuleFactory.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class MultiRegressionSplittingRuleFactory final: public SplittingRuleFactory {
3636
MultiRegressionSplittingRuleFactory(size_t num_outcomes);
3737

3838
std::unique_ptr<SplittingRule> create(size_t max_num_unique_values,
39+
const Data& data,
3940
const TreeOptions& options) const;
4041
private:
4142
size_t num_outcomes;

0 commit comments

Comments
 (0)