Skip to content

Commit 4d3d5e9

Browse files
rstzcopybara-github
authored andcommitted
[YDF] Fix training logs crash
PiperOrigin-RevId: 870862134
1 parent b14e89d commit 4d3d5e9

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -986,9 +986,14 @@ metric::proto::EvaluationResults TrainingLogToEvaluationResults(
986986
evaluation.set_loss_value(eval_set == TrainingLogEvaluationSet::kValidation
987987
? log_entry.validation_loss()
988988
: log_entry.training_loss());
989-
990-
for (int metrix_idx = 0;
991-
metrix_idx < training_logs.secondary_metric_names_size(); metrix_idx++) {
989+
int secondary_metric_size =
990+
eval_set == TrainingLogEvaluationSet::kValidation
991+
? log_entry.validation_secondary_metrics_size()
992+
: log_entry.training_secondary_metrics_size();
993+
secondary_metric_size = std::min(secondary_metric_size,
994+
training_logs.secondary_metric_names_size());
995+
996+
for (int metrix_idx = 0; metrix_idx < secondary_metric_size; metrix_idx++) {
992997
const auto& metric_name = training_logs.secondary_metric_names(metrix_idx);
993998
const auto metric_value =
994999
eval_set == TrainingLogEvaluationSet::kValidation

yggdrasil_decision_forests/port/python/ydf/model/gradient_boosted_trees_model/gradient_boosted_trees_model_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,30 @@ def test_training_logs(self):
9292
self.assertAlmostEqual(training_evaluation.loss, 0.5057407)
9393
self.assertAlmostEqual(training_evaluation.accuracy, 0.89436144)
9494

95+
def test_training_logs_with_newly_trained_model(self):
96+
dataset = {
97+
"x": np.array([0, 0, 1, 1] * 20),
98+
"y": np.array([0, 0, 1, 1] * 20),
99+
}
100+
model = specialized_learners.GradientBoostedTreesLearner(
101+
label="y",
102+
num_trees=5,
103+
validation_ratio=0.5,
104+
).train(dataset)
105+
106+
training_logs = model.training_logs()
107+
self.assertLen(training_logs, 5)
108+
109+
for log in training_logs:
110+
# Check validation evaluation
111+
self.assertIsNotNone(log.evaluation)
112+
self.assertTrue(hasattr(log.evaluation, "loss"))
113+
self.assertIsInstance(log.evaluation.loss, float)
114+
115+
# Check training evaluation
116+
self.assertIsNotNone(log.training_evaluation)
117+
self.assertTrue(hasattr(log.training_evaluation, "loss"))
118+
95119
def test_empty_training_logs(self):
96120
# This model has no training logs.
97121
training_logs = self.adult_binary_class_gbdt.training_logs()

yggdrasil_decision_forests/port/python/ydf/model/gradient_boosted_trees_model/gradient_boosted_trees_wrapper.cc

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <pybind11/numpy.h>
1919

20+
#include <cmath>
2021
#include <cstring>
2122
#include <memory>
2223
#include <utility>
@@ -94,18 +95,21 @@ std::vector<GBTCCTrainingLogEntry> GradientBoostedTreesCCModel::training_logs()
9495
const auto& label_col_spec = gbt_model_->label_col_spec();
9596
logs.reserve(training_logs.entries_size());
9697
for (const auto& entry : training_logs.entries()) {
97-
const auto& validation_evaluation =
98-
model::gradient_boosted_trees::internal::TrainingLogToEvaluationResults(
99-
entry, training_logs, gbt_model_->task(), label_col_spec,
100-
gbt_model_->loss_config(), gbt_model_->GetLossName(),
101-
model::gradient_boosted_trees::internal::TrainingLogEvaluationSet::
102-
kValidation);
103-
const auto& training_evaluation =
98+
const auto training_evaluation =
10499
model::gradient_boosted_trees::internal::TrainingLogToEvaluationResults(
105100
entry, training_logs, gbt_model_->task(), label_col_spec,
106101
gbt_model_->loss_config(), gbt_model_->GetLossName(),
107102
model::gradient_boosted_trees::internal::TrainingLogEvaluationSet::
108103
kTraining);
104+
metric::proto::EvaluationResults validation_evaluation;
105+
if (!std::isnan(gbt_model_->validation_loss())) {
106+
validation_evaluation = model::gradient_boosted_trees::internal::
107+
TrainingLogToEvaluationResults(
108+
entry, training_logs, gbt_model_->task(), label_col_spec,
109+
gbt_model_->loss_config(), gbt_model_->GetLossName(),
110+
model::gradient_boosted_trees::internal::
111+
TrainingLogEvaluationSet::kValidation);
112+
}
109113
logs.push_back({.iteration = entry.number_of_trees(),
110114
.validation_evaluation = validation_evaluation,
111115
.training_evaluation = training_evaluation});

0 commit comments

Comments
 (0)