Skip to content

Commit 03704c4

Browse files
authored
Merge simpler stopping criterion interface
Since the value type is known from the vector dynamic type at generation time of the stopping criterion, we don't need to specify it manually, and can remove the value-typed threshold parameters. This effectively means replacing typed stopping criteria like ``` .with_criteria(gko::stop::ResidualNorm<float>::build().with_baseline(gko::stop::mode::absolute).with_reduction_factor(1e-5)) ``` by the simpler ``` .with_criteria(gko::stop::abs_residual_norm(1e-5)) ``` Related PR: #1947
2 parents 3427224 + 604335b commit 03704c4

File tree

19 files changed

+402
-28
lines changed

19 files changed

+402
-28
lines changed

core/config/multigrid_config.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

5+
#include <ginkgo/core/multigrid/pgm.hpp>
6+
57
#include "core/config/parse_macro.hpp"
6-
#include "ginkgo/core/multigrid/pgm.hpp"
78

89

910
namespace gko {

core/solver/pipe_cg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
#include <ginkgo/core/base/math.hpp>
1313
#include <ginkgo/core/base/name_demangling.hpp>
1414
#include <ginkgo/core/base/precision_dispatch.hpp>
15+
#include <ginkgo/core/base/range.hpp>
1516
#include <ginkgo/core/base/utils.hpp>
1617

1718
#include "core/config/solver_config.hpp"
1819
#include "core/distributed/helpers.hpp"
1920
#include "core/solver/pipe_cg_kernels.hpp"
2021
#include "core/solver/solver_boilerplate.hpp"
21-
#include "ginkgo/core/base/range.hpp"
2222

2323

2424
namespace gko {

core/stop/iteration.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
#include "ginkgo/core/stop/iteration.hpp"
66

7+
#include <ginkgo/core/base/abstract_factory.hpp>
8+
79

810
namespace gko {
911
namespace stop {
@@ -22,5 +24,11 @@ bool Iteration::check_impl(uint8 stoppingId, bool setFinalized,
2224
}
2325

2426

27+
deferred_factory_parameter<Iteration::Factory> max_iters(size_type count)
28+
{
29+
return Iteration::build().with_max_iters(count);
30+
}
31+
32+
2533
} // namespace stop
2634
} // namespace gko

core/stop/residual_norm.cpp

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
#include "ginkgo/core/stop/residual_norm.hpp"
66

7+
#include <ginkgo/core/base/exception_helpers.hpp>
78
#include <ginkgo/core/base/precision_dispatch.hpp>
9+
#include <ginkgo/core/stop/criterion.hpp>
810

911
#include "core/base/dispatch_helper.hpp"
1012
#include "core/components/fill_array_kernels.hpp"
@@ -234,6 +236,164 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_RESIDUAL_NORM);
234236
class ImplicitResidualNorm<_type>
235237
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IMPLICIT_RESIDUAL_NORM);
236238

239+
class ResidualNormFactory;
240+
241+
struct residual_norm_factory_parameters
242+
: public enable_parameters_type<residual_norm_factory_parameters,
243+
ResidualNormFactory> {
244+
double GKO_FACTORY_PARAMETER_SCALAR(threshold, 0.0);
245+
246+
mode GKO_FACTORY_PARAMETER_SCALAR(baseline, mode::rhs_norm);
247+
248+
bool GKO_FACTORY_PARAMETER_SCALAR(implicit, false);
249+
};
250+
251+
252+
class ResidualNormFactory
253+
: public EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>,
254+
public EnablePolymorphicAssignment<ResidualNormFactory> {
255+
friend class EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>;
256+
friend class enable_parameters_type<residual_norm_factory_parameters,
257+
ResidualNormFactory>;
258+
friend EnableDefaultCriterionFactory<ResidualNormFactory, Criterion,
259+
residual_norm_factory_parameters>;
260+
261+
explicit ResidualNormFactory(
262+
std::shared_ptr<const Executor> exec,
263+
const residual_norm_factory_parameters& parameters = {})
264+
: EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>(
265+
std::move(exec)),
266+
parameters_{parameters}
267+
{}
268+
269+
std::unique_ptr<Criterion> generate_impl(CriterionArgs args) const override
270+
{
271+
std::unique_ptr<Criterion> result;
272+
auto exec = this->get_executor();
273+
run<matrix::Dense<double>, matrix::Dense<std::complex<double>>,
274+
matrix::Dense<float>, matrix::Dense<std::complex<float>>
275+
#if GINKGO_ENABLE_HALF
276+
,
277+
matrix::Dense<half>, matrix::Dense<std::complex<half>>
278+
#endif
279+
#if GINKGO_ENABLE_BFLOAT16
280+
,
281+
matrix::Dense<bfloat16>, matrix::Dense<std::complex<bfloat16>>
282+
#endif
283+
#if GINKGO_BUILD_MPI
284+
,
285+
experimental::distributed::Vector<double>,
286+
experimental::distributed::Vector<std::complex<double>>,
287+
experimental::distributed::Vector<float>,
288+
experimental::distributed::Vector<std::complex<float>>
289+
#if GINKGO_ENABLE_HALF
290+
,
291+
experimental::distributed::Vector<half>,
292+
experimental::distributed::Vector<std::complex<half>>
293+
#endif
294+
#if GINKGO_ENABLE_BFLOAT16
295+
,
296+
experimental::distributed::Vector<bfloat16>,
297+
experimental::distributed::Vector<std::complex<bfloat16>>
298+
#endif
299+
#endif
300+
>(args.b, [&](auto dense_b) {
301+
using value_type =
302+
typename std::decay_t<decltype(*dense_b)>::value_type;
303+
constexpr bool is_distributed =
304+
std::is_same_v<std::decay_t<decltype(*dense_b)>,
305+
experimental::distributed::Vector<value_type>>;
306+
using vector_type = std::conditional_t<
307+
is_distributed, experimental::distributed::Vector<value_type>,
308+
matrix::Dense<value_type>>;
309+
auto dense_x = as<vector_type>(args.x);
310+
auto dense_r = as<vector_type>(args.initial_residual);
311+
auto cast_threshold = static_cast<remove_complex<value_type>>(
312+
this->parameters_.threshold);
313+
auto cast_args =
314+
CriterionArgs{args.system_matrix, dense_b, dense_x, dense_r};
315+
if (static_cast<double>(cast_threshold) <= 0.0) {
316+
GKO_INVALID_STATE(
317+
"stopping criterion threshold is zero or negative when "
318+
"cast to ValueType");
319+
}
320+
if (this->parameters_.implicit) {
321+
result = ImplicitResidualNorm<value_type>::build()
322+
.with_baseline(this->parameters_.baseline)
323+
.with_reduction_factor(cast_threshold)
324+
.on(exec)
325+
->generate(cast_args);
326+
} else {
327+
result = ResidualNorm<value_type>::build()
328+
.with_baseline(this->parameters_.baseline)
329+
.with_reduction_factor(cast_threshold)
330+
.on(exec)
331+
->generate(cast_args);
332+
}
333+
});
334+
return result;
335+
}
336+
337+
residual_norm_factory_parameters parameters_;
338+
};
339+
340+
341+
deferred_factory_parameter<CriterionFactory> absolute_residual_norm(
342+
double tolerance)
343+
{
344+
return residual_norm_factory_parameters{}
345+
.with_threshold(tolerance)
346+
.with_baseline(mode::absolute);
347+
}
348+
349+
350+
deferred_factory_parameter<CriterionFactory> relative_residual_norm(
351+
double tolerance)
352+
{
353+
return residual_norm_factory_parameters{}
354+
.with_threshold(tolerance)
355+
.with_baseline(mode::rhs_norm);
356+
}
357+
358+
359+
deferred_factory_parameter<CriterionFactory> initial_residual_norm(
360+
double tolerance)
361+
{
362+
return residual_norm_factory_parameters{}
363+
.with_threshold(tolerance)
364+
.with_baseline(mode::initial_resnorm);
365+
}
366+
367+
368+
deferred_factory_parameter<CriterionFactory> absolute_implicit_residual_norm(
369+
double tolerance)
370+
{
371+
return residual_norm_factory_parameters{}
372+
.with_threshold(tolerance)
373+
.with_baseline(mode::absolute)
374+
.with_implicit(true);
375+
}
376+
377+
378+
deferred_factory_parameter<CriterionFactory> relative_implicit_residual_norm(
379+
double tolerance)
380+
{
381+
return residual_norm_factory_parameters{}
382+
.with_threshold(tolerance)
383+
.with_baseline(mode::rhs_norm)
384+
.with_implicit(true);
385+
}
386+
387+
388+
deferred_factory_parameter<CriterionFactory> initial_implicit_residual_norm(
389+
double tolerance)
390+
{
391+
return residual_norm_factory_parameters{}
392+
.with_threshold(tolerance)
393+
.with_baseline(mode::initial_resnorm)
394+
.with_implicit(true);
395+
}
396+
237397

238398
} // namespace stop
239399
} // namespace gko

core/stop/time.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

55
#include "ginkgo/core/stop/time.hpp"
66

7+
#include <chrono>
8+
9+
#include <ginkgo/core/base/abstract_factory.hpp>
10+
711

812
namespace gko {
913
namespace stop {
@@ -22,5 +26,12 @@ bool Time::check_impl(uint8 stoppingId, bool setFinalized,
2226
}
2327

2428

29+
deferred_factory_parameter<Time::Factory> time_limit(
30+
std::chrono::nanoseconds time)
31+
{
32+
return Time::build().with_time_limit(time);
33+
}
34+
35+
2536
} // namespace stop
2637
} // namespace gko

core/test/matrix/csr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
#include <gtest/gtest.h>
66

77
#include <ginkgo/core/base/device_matrix_data.hpp>
8+
#include <ginkgo/core/base/exception.hpp>
89
#include <ginkgo/core/matrix/csr.hpp>
910

1011
#include "core/test/utils.hpp"
11-
#include "ginkgo/core/base/exception.hpp"
1212

1313

1414
namespace {

include/ginkgo/core/base/std_extensions.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include <memory>
1212
#include <type_traits>
1313

14-
#include "ginkgo/core/base/types.hpp"
14+
#include <ginkgo/core/base/types.hpp>
1515

1616

1717
// This header provides implementations of useful utilities introduced into the

include/ginkgo/core/stop/iteration.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#define GKO_PUBLIC_CORE_STOP_ITERATION_HPP_
77

88

9+
#include <ginkgo/core/base/abstract_factory.hpp>
910
#include <ginkgo/core/stop/criterion.hpp>
1011

1112

@@ -58,6 +59,29 @@ class Iteration : public EnablePolymorphicObject<Iteration, Criterion> {
5859
};
5960

6061

62+
/**
63+
* Creates the precursor to an Iteration stopping criterion factory, to be used
64+
* in conjunction with `.with_criteria(...)` function calls when building a
65+
* solver factory. This stopping criterion will stop the iteration after `count`
66+
* iterations of the solver have finished.
67+
*
68+
* Full usage example: Stop after 100 iterations or when the relative residual
69+
* norm is below $10^{-10}$, whichever happens first.
70+
* ```cpp
71+
* auto factory = gko::solver::Cg<double>::build()
72+
* .with_criteria(
73+
* gko::stop::max_iters(100),
74+
* gko::stop::relative_residual_norm(1e-10))
75+
* .on(exec);
76+
* ```
77+
*
78+
* @param count the number of iterations after which to stop
79+
* @return a deferred_factory_parameter that can be passed to the
80+
* `with_criteria` function when building a solver.
81+
*/
82+
deferred_factory_parameter<Iteration::Factory> max_iters(size_type count);
83+
84+
6185
} // namespace stop
6286
} // namespace gko
6387

include/ginkgo/core/stop/residual_norm.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,61 @@ class ImplicitResidualNorm : public ResidualNormBase<ValueType> {
224224
};
225225

226226

227+
/**
228+
* Creates the precursor to a ResidualNorm stopping criterion factory, to be
229+
* used in conjunction with `.with_criteria(...)` function calls when building a
230+
* solver factory. This stopping criterion will stop the iteration after the
231+
* residual norm has decreased below the specified value or by the specified
232+
* amount.
233+
*
234+
* Full usage example: Stop after 100 iterations or when the absolute residual
235+
* norm is below $10^{-10}$, whichever happens first.
236+
* ```cpp
237+
* auto factory = gko::solver::Cg<double>::build()
238+
* .with_criteria(
239+
* gko::stop::max_iters(100),
240+
* gko::stop::absolute_residual_norm(1e-10))
241+
* .on(exec);
242+
* ```
243+
*
244+
* @param tolerance the value the residual norm needs to be below.
245+
* With residual $r$, initial guess $x_0$, right-hand side $b$, matrix $A$,
246+
* `absolute` means the exact value of the norm $||r||$,
247+
* `relative` means the norm relative to the right-hand side $||r||/||b||$,
248+
* `initial` means the norm relative to the initial residual
249+
* $||r||/||b - A x_0||$.
250+
* An implicit stopping criterion is only available with some solvers, and
251+
* refers to either the energy norm $||r||_A$ in short-recurrence solvers
252+
* like Cg or the euclidian norm $||r||$ in solvers like GMRES.
253+
* Implicit residual norms are cheaper to compute, but may be less precise
254+
* due to accumulating rounding errors.
255+
* @return a deferred_factory_parameter that can be passed to the
256+
* `with_criteria` function when building a solver.
257+
*/
258+
deferred_factory_parameter<CriterionFactory> absolute_residual_norm(
259+
double tolerance);
260+
261+
/** @copydoc absolute_residual_norm */
262+
deferred_factory_parameter<CriterionFactory> relative_residual_norm(
263+
double tolerance);
264+
265+
/** @copydoc absolute_residual_norm */
266+
deferred_factory_parameter<CriterionFactory> initial_residual_norm(
267+
double tolerance);
268+
269+
/** @copydoc absolute_residual_norm */
270+
deferred_factory_parameter<CriterionFactory> absolute_implicit_residual_norm(
271+
double tolerance);
272+
273+
/** @copydoc absolute_residual_norm */
274+
deferred_factory_parameter<CriterionFactory> relative_implicit_residual_norm(
275+
double tolerance);
276+
277+
/** @copydoc absolute_residual_norm */
278+
deferred_factory_parameter<CriterionFactory> initial_implicit_residual_norm(
279+
double tolerance);
280+
281+
227282
// The following classes are deprecated, but they internally reference
228283
// themselves. To reduce unnecessary warnings, we disable deprecation warnings
229284
// for the definition of these classes.

0 commit comments

Comments
 (0)