Skip to content

Commit 6a3f26d

Browse files
calculating sample_weight_discretized more correctly
1 parent be8196a commit 6a3f26d

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

cpp/term.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -432,12 +432,20 @@ void Term::discretize_data_by_bin()
432432
values_discretized[i]=sorted_vectors.values_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).mean();
433433
}
434434

435-
if(sorted_vectors.sample_weight_sorted.size()>0)
435+
sample_weight_discretized.resize(bins_start_index.size());
436+
bool sample_weights_were_provided_by_user{sorted_vectors.sample_weight_sorted.size()>0};
437+
if(sample_weights_were_provided_by_user)
436438
{
437-
sample_weight_discretized.resize(bins_start_index.size());
438439
for (size_t i = 0; i < bins_start_index.size(); ++i)
439440
{
440-
sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).mean();
441+
sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).sum();
442+
}
443+
}
444+
else
445+
{
446+
for (size_t i = 0; i < bins_start_index.size(); ++i)
447+
{
448+
sample_weight_discretized[i]=static_cast<double>(observations_in_bins[i]);
441449
}
442450
}
443451
}

0 commit comments

Comments
 (0)