
Commit f2d976a

[backport] Avoid using mean intercept for rmsle. (dmlc#11588) (dmlc#11593)
1 parent 05daeb6 commit f2d976a

6 files changed: +95 −54 lines


src/common/pseudo_huber.cc

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@
  */
 #include "pseudo_huber.h"
 namespace xgboost {
-DMLC_REGISTER_PARAMETER(PesudoHuberParam);
+DMLC_REGISTER_PARAMETER(PseudoHuberParam);
 }

src/common/pseudo_huber.h

Lines changed: 2 additions & 2 deletions
@@ -6,10 +6,10 @@
 #include "xgboost/parameter.h"
 
 namespace xgboost {
-struct PesudoHuberParam : public XGBoostParameter<PesudoHuberParam> {
+struct PseudoHuberParam : public XGBoostParameter<PseudoHuberParam> {
   float huber_slope{1.0};
 
-  DMLC_DECLARE_PARAMETER(PesudoHuberParam) {
+  DMLC_DECLARE_PARAMETER(PseudoHuberParam) {
     DMLC_DECLARE_FIELD(huber_slope)
         .set_default(1.0f)
         .describe("The delta term in Pseudo-Huber loss.");

src/metric/elementwise_metric.cu

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ struct EvalRowLogLoss {
 };
 
 class PseudoErrorLoss : public MetricNoCache {
-  PesudoHuberParam param_;
+  PseudoHuberParam param_;
 
  public:
   const char* Name() const override { return "mphe"; }

src/objective/quantile_obj.cu

Lines changed: 1 addition & 3 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023-2024, XGBoost contributors
+ * Copyright 2023-2025, XGBoost contributors
  */
 #include <array>    // std::array
 #include <cstddef>  // std::size_t
@@ -10,15 +10,13 @@
 #include "../common/quantile_loss_utils.h"  // QuantileLossParam
 #include "../common/stats.h"                // Quantile,WeightedQuantile
 #include "adaptive.h"                       // UpdateTreeLeaf
-#include "dmlc/parameter.h"                 // DMLC_DECLARE_PARAMETER
 #include "init_estimation.h"                // CheckInitInputs
 #include "xgboost/base.h"                   // GradientPair,XGBOOST_DEVICE,bst_target_t
 #include "xgboost/data.h"                   // MetaInfo
 #include "xgboost/host_device_vector.h"     // HostDeviceVector
 #include "xgboost/json.h"                   // Json,String,ToJson,FromJson
 #include "xgboost/linalg.h"                 // Tensor,MakeTensorView,MakeVec
 #include "xgboost/objective.h"              // ObjFunction
-#include "xgboost/parameter.h"              // XGBoostParameter
 
 #if defined(XGBOOST_USE_CUDA)
 

src/objective/regression_loss.h

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by XGBoost contributors
+ * Copyright 2017-2025, XGBoost contributors
  */
 #ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
 #define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
@@ -9,7 +9,6 @@
 #include <cmath>
 
 #include "../common/math.h"
-#include "xgboost/data.h"  // MetaInfo
 #include "xgboost/logging.h"
 #include "xgboost/task.h"  // ObjInfo
 
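This header defines the loss functors used throughout the commit. As context, a simplified sketch of the SquaredLogError interface, inferred from the call sites visible in this diff (CheckLabel, FirstOrderGradient, SecondOrderGradient, LabelErrorMsg, Name); the exact bodies in the header may differ:

// Sketch only: a squared-log-error functor consistent with the calls below.
// L(p, y) = 0.5 * (log1p(p) - log1p(y))^2, defined for labels y > -1.
struct SquaredLogError {
  XGBOOST_DEVICE static bool CheckLabel(float y) { return y > -1.0f; }
  XGBOOST_DEVICE static float FirstOrderGradient(float p, float y) {
    p = fmaxf(p, -1.0f + 1e-6f);  // keep log1p(p) defined
    return (std::log1p(p) - std::log1p(y)) / (p + 1.0f);
  }
  XGBOOST_DEVICE static float SecondOrderGradient(float p, float y) {
    p = fmaxf(p, -1.0f + 1e-6f);
    float h = (1.0f - std::log1p(p) + std::log1p(y)) / ((p + 1.0f) * (p + 1.0f));
    return fmaxf(h, 1e-6f);  // keep the hessian positive
  }
  static const char* Name() { return "reg:squaredlogerror"; }
  static const char* LabelErrorMsg() {
    return "label must be greater than -1 for rmsle so that log(label + 1) is defined";
  }
};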

src/objective/regression_obj.cu

Lines changed: 89 additions & 45 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2015-2024, XGBoost Contributors
+ * Copyright 2015-2025, XGBoost Contributors
  * \file regression_obj.cu
  * \brief Definition of single-value regression and classification objectives.
  * \author Tianqi Chen, Kailong Chen
@@ -9,7 +9,6 @@
 #include <algorithm>
 #include <cmath>
 #include <cstdint>  // std::int32_t
-#include <memory>
 #include <vector>
 
 #include "../common/common.h"
@@ -53,56 +52,56 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds
   CheckInitInputs(info);
   CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
 }
+
+template <typename Loss>
+void ValidateLabel(Context const* ctx, MetaInfo const& info) {
+  auto label = info.labels.View(ctx->Device());
+  auto valid = ctx->DispatchDevice(
+      [&] {
+        return std::all_of(linalg::cbegin(label), linalg::cend(label),
+                           [](float y) -> bool { return Loss::CheckLabel(y); });
+      },
+      [&] {
+#if defined(XGBOOST_USE_CUDA)
+        auto cuctx = ctx->CUDACtx();
+        auto it = dh::MakeTransformIterator<bool>(
+            thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool {
+              auto [m, n] = linalg::UnravelIndex(i, label.Shape());
+              return Loss::CheckLabel(label(m, n));
+            });
+        return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{});
+#else
+        common::AssertGPUSupport();
+        return false;
+#endif  // defined(XGBOOST_USE_CUDA)
+      },
+      [&] {
+#if defined(XGBOOST_USE_SYCL)
+        return sycl::linalg::Validate(ctx->Device(), label,
+                                      [](float y) -> bool { return Loss::CheckLabel(y); });
+#else
+        common::AssertSYCLSupport();
+        return false;
+#endif  // defined(XGBOOST_USE_SYCL)
+      });
+  if (!valid) {
+    LOG(FATAL) << Loss::LabelErrorMsg();
+  }
+}
 }  // anonymous namespace
 
 #if defined(XGBOOST_USE_CUDA)
 DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
 #endif  // defined(XGBOOST_USE_CUDA)
 
-
-
 template<typename Loss>
 class RegLossObj : public FitInterceptGlmLike {
  protected:
   HostDeviceVector<float> additional_input_;
 
  public:
-  void ValidateLabel(MetaInfo const& info) {
-    auto label = info.labels.View(ctx_->Device());
-    auto valid = ctx_->DispatchDevice(
-        [&] {
-          return std::all_of(linalg::cbegin(label), linalg::cend(label),
-                             [](float y) -> bool { return Loss::CheckLabel(y); });
-        },
-        [&] {
-#if defined(XGBOOST_USE_CUDA)
-          auto cuctx = ctx_->CUDACtx();
-          auto it = dh::MakeTransformIterator<bool>(
-              thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> bool {
-                auto [m, n] = linalg::UnravelIndex(i, label.Shape());
-                return Loss::CheckLabel(label(m, n));
-              });
-          return dh::Reduce(cuctx->CTP(), it, it + label.Size(), true, thrust::logical_and<>{});
-#else
-          common::AssertGPUSupport();
-          return false;
-#endif  // defined(XGBOOST_USE_CUDA)
-        },
-        [&] {
-#if defined(XGBOOST_USE_SYCL)
-          return sycl::linalg::Validate(ctx_->Device(), label,
-                                        [](float y) -> bool { return Loss::CheckLabel(y); });
-#else
-          common::AssertSYCLSupport();
-          return false;
-#endif  // defined(XGBOOST_USE_SYCL)
-        });
-    if (!valid) {
-      LOG(FATAL) << Loss::LabelErrorMsg();
-    }
-  }
   // 0 - scale_pos_weight, 1 - is_null_weight
-  RegLossObj(): additional_input_(2) {}
+  RegLossObj() : additional_input_(2) {}
 
   void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
     param_.UpdateAllowUnknown(args);
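The hunk above turns ValidateLabel into a free function that takes the Context explicitly, so objectives other than RegLossObj (like the new SquaredLogErrorRegression below) can reuse it. For readers unfamiliar with the three-lambda dispatch it relies on, here is a minimal stand-in for that pattern; Device and DispatchDevice here are illustrative names, not XGBoost's exact API:

#include <utility>

// Hypothetical sketch of a device dispatcher: one callable per backend,
// and only the callable matching the active device is invoked.
enum class Device { kCPU, kCUDA, kSYCL };

template <typename CPUFn, typename CUDAFn, typename SYCLFn>
auto DispatchDevice(Device d, CPUFn&& cpu, CUDAFn&& cuda, SYCLFn&& sycl) {
  switch (d) {
    case Device::kCUDA:
      return std::forward<CUDAFn>(cuda)();
    case Device::kSYCL:
      return std::forward<SYCLFn>(sycl)();
    default:
      return std::forward<CPUFn>(cpu)();
  }
}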
@@ -119,7 +118,7 @@ class RegLossObj : public FitInterceptGlmLike {
                    std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
     CheckRegInputs(info, preds);
     if (iter == 0) {
-      ValidateLabel(info);
+      ValidateLabel<Loss>(this->ctx_, info);
     }
 
     size_t const ndata = preds.Size();
@@ -222,10 +221,6 @@ XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, LinearSquareLoss::Name())
     .describe("Regression with squared error.")
     .set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
 
-XGBOOST_REGISTER_OBJECTIVE(SquareLogError, SquaredLogError::Name())
-    .describe("Regression with root mean squared logarithmic error.")
-    .set_body([]() { return new RegLossObj<SquaredLogError>(); });
-
 XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
     .describe("Logistic regression for probability regression task.")
     .set_body([]() { return new RegLossObj<LogisticRegression>(); });
@@ -251,8 +246,57 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
       return new RegLossObj<LinearSquareLoss>(); });
 // End deprecated
 
+class SquaredLogErrorRegression : public FitIntercept {
+ public:
+  static auto Name() { return SquaredLogError::Name(); }
+
+  void Configure(Args const&) override {}
+  [[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
+  [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
+    return std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
+  }
+  void GetGradient(HostDeviceVector<bst_float> const& preds, const MetaInfo& info,
+                   std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) override {
+    if (iter == 0) {
+      ValidateLabel<SquaredLogError>(this->ctx_, info);
+    }
+    auto labels = info.labels.View(ctx_->Device());
+
+    out_gpair->SetDevice(ctx_->Device());
+    out_gpair->Reshape(info.num_row_, this->Targets(info));
+    auto gpair = out_gpair->View(ctx_->Device());
+
+    preds.SetDevice(ctx_->Device());
+    auto predt = linalg::MakeTensorView(ctx_, &preds, info.num_row_, this->Targets(info));
+
+    info.weights_.SetDevice(ctx_->Device());
+    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
+                                                 : info.weights_.ConstDeviceSpan()};
+    linalg::ElementWiseKernel(this->ctx_, labels,
+                              [=] XGBOOST_DEVICE(std::size_t i, std::size_t j) mutable {
+                                auto p = predt(i, j);
+                                auto y = labels(i, j);
+                                auto w = weight[i];
+                                auto grad = SquaredLogError::FirstOrderGradient(p, y);
+                                auto hess = SquaredLogError::SecondOrderGradient(p, y);
+                                gpair(i) = {grad * w, hess * w};
+                              });
+  }
+  [[nodiscard]] const char* DefaultEvalMetric() const override { return "rmsle"; }
+
+  void SaveConfig(Json* p_out) const override {
+    auto& out = *p_out;
+    out["name"] = String(Name());
+  }
+  void LoadConfig(Json const&) override {}
+};
+
+XGBOOST_REGISTER_OBJECTIVE(SquaredLogErrorRegression, SquaredLogErrorRegression::Name())
+    .describe("Root mean squared log error.")
+    .set_body([]() { return new SquaredLogErrorRegression(); });
+
 class PseudoHuberRegression : public FitIntercept {
-  PesudoHuberParam param_;
+  PseudoHuberParam param_;
 
  public:
   void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
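As a reference for the kernel in GetGradient above: with the squared log error L(p, y) = ½(log(1+p) − log(1+y))², the per-element gradient and hessian that the loss functor supplies are (standard calculus, shown here for context)

$$ \frac{\partial L}{\partial p} = \frac{\log(1+p) - \log(1+y)}{1+p}, \qquad \frac{\partial^2 L}{\partial p^2} = \frac{1 - \log(1+p) + \log(1+y)}{(1+p)^2}. $$

Both are scaled by the optional instance weight w before being written into out_gpair. Deriving the new objective from FitIntercept, rather than reusing RegLossObj (which derives FitInterceptGlmLike), is what avoids the mean-based intercept named in the commit title: the base score is instead fit from these gradients.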
