Skip to content

Commit 6e76f39

Browse files
committed
add min_iters to simple criterion
1 parent 8124ab4 commit 6e76f39

File tree

4 files changed

+82
-5
lines changed

4 files changed

+82
-5
lines changed

core/config/config_helper.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <ginkgo/core/base/exception_helpers.hpp>
1010
#include <ginkgo/core/base/lin_op.hpp>
1111
#include <ginkgo/core/config/registry.hpp>
12+
#include <ginkgo/core/stop/iteration.hpp>
1213

1314
#include "core/config/registry_accessor.hpp"
1415
#include "core/config/stop_config.hpp"
@@ -118,21 +119,32 @@ parse_minimal_criteria(const pnode& config, const registry& context,
118119
// We use additional scope such that we check it before the following map
119120
// throw exception
120121
{
121-
config_check_decorator config_check(config);
122+
config_check_decorator config_check(config, {{"min_iters"}});
122123
for (const auto& it : criterion_map) {
123124
config_check.get(it.first);
124125
}
125126
}
126127

128+
127129
std::vector<deferred_factory_parameter<const stop::CriterionFactory>> res;
128130
for (const auto& it : config.get_map()) {
129-
if (it.first == "value_type") {
131+
if (it.first == "value_type" || it.first == "min_iters") {
130132
continue;
131133
}
132134
res.emplace_back(
133135
criterion_map.at(it.first)(config, context, updated_td));
134136
}
135-
return res;
137+
if (auto& obj = config.get("min_iters")) {
138+
auto counts = get_value<size_type>(obj);
139+
if (res.size() == 1) {
140+
return {stop::min_iters(counts, res.at(0))};
141+
} else {
142+
return {stop::min_iters(
143+
counts, stop::Combined::build().with_criteria(res))};
144+
}
145+
} else {
146+
return res;
147+
}
136148
}
137149

138150

core/stop/iteration.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ deferred_factory_parameter<Iteration::Factory> max_iters(size_type count)
3131

3232

3333
deferred_factory_parameter<CriterionFactory> min_iters(
34-
size_type count, deferred_factory_parameter<CriterionFactory> criterion)
34+
size_type count,
35+
deferred_factory_parameter<const CriterionFactory> criterion)
3536
{
3637
return MinIterationWrapper::build()
3738
.with_min_iters(count)

core/test/config/config.cpp

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <ginkgo/core/stop/time.hpp>
1717

1818
#include "core/config/config_helper.hpp"
19+
#include "core/stop/iteration.hpp"
1920
#include "core/test/utils.hpp"
2021

2122

@@ -286,6 +287,68 @@ TEST_F(Config, GenerateCriteriaFromMinimalConfigWithValueType)
286287
}
287288

288289

290+
TEST_F(Config, GenerateOneCriteriaFromMinimalConfigWithMinIters)
291+
{
292+
auto reg = registry();
293+
reg.emplace("precond", this->mtx);
294+
pnode minimal_stop{{{"iteration", pnode{5}}, {"min_iters", pnode{10}}}};
295+
296+
pnode p{{{"criteria", minimal_stop}}};
297+
auto obj = std::dynamic_pointer_cast<gko::solver::Cg<float>::Factory>(
298+
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float32", "void"})
299+
.on(this->exec));
300+
301+
ASSERT_NE(obj, nullptr);
302+
auto criteria = gko::as<gko::stop::MinIterationWrapper::Factory>(
303+
obj->get_parameters().criteria.at(0));
304+
auto inner = gko::as<gko::stop::Iteration::Factory>(
305+
criteria->get_parameters().inner_criterion);
306+
ASSERT_EQ(criteria->get_parameters().min_iters, 10);
307+
ASSERT_EQ(inner->get_parameters().max_iters, 5);
308+
}
309+
310+
311+
TEST_F(Config, GenerateCriteriaFromMinimalConfigWithMinIters)
312+
{
313+
auto reg = registry();
314+
reg.emplace("precond", this->mtx);
315+
pnode minimal_stop{{{"iteration", pnode{5}},
316+
{"time", pnode{100}},
317+
{"min_iters", pnode{10}}}};
318+
319+
pnode p{{{"criteria", minimal_stop}}};
320+
auto obj = std::dynamic_pointer_cast<gko::solver::Cg<float>::Factory>(
321+
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float32", "void"})
322+
.on(this->exec));
323+
324+
ASSERT_NE(obj, nullptr);
325+
auto criteria = gko::as<gko::stop::MinIterationWrapper::Factory>(
326+
obj->get_parameters().criteria.at(0));
327+
auto inner = gko::as<gko::stop::Combined::Factory>(
328+
criteria->get_parameters().inner_criterion);
329+
ASSERT_EQ(criteria->get_parameters().min_iters, 10);
330+
ASSERT_EQ(inner->get_parameters().criteria.size(), 2);
331+
if (auto inner1 = gko::as<gko::stop::Iteration::Factory>(
332+
inner->get_parameters().criteria.at(0))) {
333+
auto inner2 = gko::as<gko::stop::Time::Factory>(
334+
inner->get_parameters().criteria.at(1));
335+
ASSERT_EQ(inner1->get_parameters().max_iters, 5);
336+
ASSERT_EQ(inner2->get_parameters().time_limit,
337+
std::chrono::nanoseconds(100));
338+
} else if (auto inner1 = gko::as<gko::stop::Time::Factory>(
339+
inner->get_parameters().criteria.at(0))) {
340+
auto inner2 = gko::as<gko::stop::Iteration::Factory>(
341+
inner->get_parameters().criteria.at(1));
342+
ASSERT_EQ(inner2->get_parameters().max_iters, 5);
343+
ASSERT_EQ(inner1->get_parameters().time_limit,
344+
std::chrono::nanoseconds(100));
345+
} else {
346+
ASSERT_TRUE(false)
347+
<< "the first criterion is not ResidualNorm or Time.";
348+
}
349+
}
350+
351+
289352
TEST_F(Config, MinimalConfigThrowWhenKeyIsInvalid)
290353
{
291354
pnode minimal_stop{{{"time", pnode{100}}, {"invalid", pnode{"no"}}}};

include/ginkgo/core/stop/iteration.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ deferred_factory_parameter<Iteration::Factory> max_iters(size_type count);
108108
* `with_criteria` function when building a solver.
109109
*/
110110
deferred_factory_parameter<CriterionFactory> min_iters(
111-
size_type count, deferred_factory_parameter<CriterionFactory> criterion);
111+
size_type count,
112+
deferred_factory_parameter<const CriterionFactory> criterion);
112113

113114

114115
/**

0 commit comments

Comments
 (0)