|
6 | 6 | #include <xgboost/objective.h>
|
7 | 7 |
|
8 | 8 | #include "../helpers.h"
|
| 9 | +#include "../objective_helpers.h" |
9 | 10 |
|
10 | 11 | TEST(Objective, UnknownFunction) {
|
11 | 12 | xgboost::ObjFunction* obj = nullptr;
|
@@ -43,4 +44,61 @@ TEST(Objective, PredTransform) {
|
43 | 44 | ASSERT_TRUE(predts.HostCanWrite());
|
44 | 45 | }
|
45 | 46 | }
|
| 47 | + |
| 48 | +class TestDefaultObjConfig : public ::testing::TestWithParam<std::string> { |
| 49 | + Context ctx_; |
| 50 | + |
| 51 | + public: |
| 52 | + void Run(std::string objective) { |
| 53 | + auto Xy = MakeFmatForObjTest(objective); |
| 54 | + std::unique_ptr<Learner> learner{Learner::Create({Xy})}; |
| 55 | + std::unique_ptr<ObjFunction> objfn{ObjFunction::Create(objective, &ctx_)}; |
| 56 | + |
| 57 | + learner->SetParam("objective", objective); |
| 58 | + if (objective.find("multi") != std::string::npos) { |
| 59 | + learner->SetParam("num_class", "3"); |
| 60 | + objfn->Configure(Args{{"num_class", "3"}}); |
| 61 | + } else if (objective.find("quantile") != std::string::npos) { |
| 62 | + learner->SetParam("quantile_alpha", "0.5"); |
| 63 | + objfn->Configure(Args{{"quantile_alpha", "0.5"}}); |
| 64 | + } else { |
| 65 | + objfn->Configure(Args{}); |
| 66 | + } |
| 67 | + learner->Configure(); |
| 68 | + learner->UpdateOneIter(0, Xy); |
| 69 | + learner->EvalOneIter(0, {Xy}, {"train"}); |
| 70 | + Json config{Object{}}; |
| 71 | + learner->SaveConfig(&config); |
| 72 | + auto jobj = get<Object const>(config["learner"]["objective"]); |
| 73 | + |
| 74 | + ASSERT_TRUE(jobj.find("name") != jobj.cend()); |
| 75 | + // FIXME(jiamingy): We should have the following check, but some legacy parameter like |
| 76 | + // "pos_weight", "delta_step" in objectives are not in metrics. |
| 77 | + |
| 78 | + // if (jobj.size() > 1) { |
| 79 | + // ASSERT_FALSE(IsA<Null>(objfn->DefaultMetricConfig())); |
| 80 | + // } |
| 81 | + auto mconfig = objfn->DefaultMetricConfig(); |
| 82 | + if (!IsA<Null>(mconfig)) { |
| 83 | + // make sure metric can handle it |
| 84 | + std::unique_ptr<Metric> metricfn{Metric::Create(get<String const>(mconfig["name"]), &ctx_)}; |
| 85 | + metricfn->LoadConfig(mconfig); |
| 86 | + Json loaded(Object{}); |
| 87 | + metricfn->SaveConfig(&loaded); |
| 88 | + metricfn->Configure(Args{}); |
| 89 | + ASSERT_EQ(mconfig, loaded); |
| 90 | + } |
| 91 | + } |
| 92 | +}; |
| 93 | + |
| 94 | +TEST_P(TestDefaultObjConfig, Objective) { |
| 95 | + std::string objective = GetParam(); |
| 96 | + this->Run(objective); |
| 97 | +} |
| 98 | + |
| 99 | +INSTANTIATE_TEST_SUITE_P(Objective, TestDefaultObjConfig, |
| 100 | + ::testing::ValuesIn(MakeObjNamesForTest()), |
| 101 | + [](const ::testing::TestParamInfo<TestDefaultObjConfig::ParamType>& info) { |
| 102 | + return ObjTestNameGenerator(info); |
| 103 | + }); |
46 | 104 | } // namespace xgboost
|
0 commit comments