|
1 | | -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors |
| 1 | +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors |
2 | 2 | // |
3 | 3 | // SPDX-License-Identifier: BSD-3-Clause |
4 | 4 |
|
5 | 5 | #include "ginkgo/core/stop/residual_norm.hpp" |
6 | 6 |
|
| 7 | +#include <ginkgo/core/base/exception_helpers.hpp> |
7 | 8 | #include <ginkgo/core/base/precision_dispatch.hpp> |
| 9 | +#include <ginkgo/core/stop/criterion.hpp> |
8 | 10 |
|
9 | 11 | #include "core/base/dispatch_helper.hpp" |
10 | 12 | #include "core/components/fill_array_kernels.hpp" |
@@ -234,6 +236,164 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_RESIDUAL_NORM); |
234 | 236 | class ImplicitResidualNorm<_type> |
235 | 237 | GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IMPLICIT_RESIDUAL_NORM); |
236 | 238 |
|
| 239 | +class ResidualNormFactory; |
| 240 | + |
| 241 | +struct residual_norm_factory_parameters |
| 242 | + : public enable_parameters_type<residual_norm_factory_parameters, |
| 243 | + ResidualNormFactory> { |
| 244 | + double GKO_FACTORY_PARAMETER_SCALAR(threshold, 0.0); |
| 245 | + |
| 246 | + mode GKO_FACTORY_PARAMETER_SCALAR(baseline, mode::rhs_norm); |
| 247 | + |
| 248 | + bool GKO_FACTORY_PARAMETER_SCALAR(implicit, false); |
| 249 | +}; |
| 250 | + |
| 251 | + |
| 252 | +class ResidualNormFactory |
| 253 | + : public EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>, |
| 254 | + public EnablePolymorphicAssignment<ResidualNormFactory> { |
| 255 | + friend class EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>; |
| 256 | + friend class enable_parameters_type<residual_norm_factory_parameters, |
| 257 | + ResidualNormFactory>; |
| 258 | + friend EnableDefaultCriterionFactory<ResidualNormFactory, Criterion, |
| 259 | + residual_norm_factory_parameters>; |
| 260 | + |
| 261 | + explicit ResidualNormFactory( |
| 262 | + std::shared_ptr<const Executor> exec, |
| 263 | + const residual_norm_factory_parameters& parameters = {}) |
| 264 | + : EnablePolymorphicObject<ResidualNormFactory, CriterionFactory>( |
| 265 | + std::move(exec)), |
| 266 | + parameters_{parameters} |
| 267 | + {} |
| 268 | + |
| 269 | + std::unique_ptr<Criterion> generate_impl(CriterionArgs args) const override |
| 270 | + { |
| 271 | + std::unique_ptr<Criterion> result; |
| 272 | + auto exec = this->get_executor(); |
| 273 | + run<matrix::Dense<double>, matrix::Dense<std::complex<double>>, |
| 274 | + matrix::Dense<float>, matrix::Dense<std::complex<float>> |
| 275 | +#if GINKGO_ENABLE_HALF |
| 276 | + , |
| 277 | + matrix::Dense<half>, matrix::Dense<std::complex<half>> |
| 278 | +#endif |
| 279 | +#if GINKGO_ENABLE_BFLOAT16 |
| 280 | + , |
| 281 | + matrix::Dense<bfloat16>, matrix::Dense<std::complex<bfloat16>> |
| 282 | +#endif |
| 283 | +#if GINKGO_BUILD_MPI |
| 284 | + , |
| 285 | + experimental::distributed::Vector<double>, |
| 286 | + experimental::distributed::Vector<std::complex<double>>, |
| 287 | + experimental::distributed::Vector<float>, |
| 288 | + experimental::distributed::Vector<std::complex<float>> |
| 289 | +#if GINKGO_ENABLE_HALF |
| 290 | + , |
| 291 | + experimental::distributed::Vector<half>, |
| 292 | + experimental::distributed::Vector<std::complex<half>> |
| 293 | +#endif |
| 294 | +#if GINKGO_ENABLE_BFLOAT16 |
| 295 | + , |
| 296 | + experimental::distributed::Vector<bfloat16>, |
| 297 | + experimental::distributed::Vector<std::complex<bfloat16>> |
| 298 | +#endif |
| 299 | +#endif |
| 300 | + >(args.b, [&](auto dense_b) { |
| 301 | + using value_type = |
| 302 | + typename std::decay_t<decltype(*dense_b)>::value_type; |
| 303 | + constexpr bool is_distributed = |
| 304 | + std::is_same_v<std::decay_t<decltype(*dense_b)>, |
| 305 | + experimental::distributed::Vector<value_type>>; |
| 306 | + using vector_type = std::conditional_t< |
| 307 | + is_distributed, experimental::distributed::Vector<value_type>, |
| 308 | + matrix::Dense<value_type>>; |
| 309 | + auto dense_x = as<vector_type>(args.x); |
| 310 | + auto dense_r = as<vector_type>(args.initial_residual); |
| 311 | + auto cast_threshold = static_cast<remove_complex<value_type>>( |
| 312 | + this->parameters_.threshold); |
| 313 | + auto cast_args = |
| 314 | + CriterionArgs{args.system_matrix, dense_b, dense_x, dense_r}; |
| 315 | + if (static_cast<double>(cast_threshold) <= 0.0) { |
| 316 | + GKO_INVALID_STATE( |
| 317 | + "stopping criterion threshold is zero or negative when " |
| 318 | + "cast to ValueType"); |
| 319 | + } |
| 320 | + if (this->parameters_.implicit) { |
| 321 | + result = ImplicitResidualNorm<value_type>::build() |
| 322 | + .with_baseline(this->parameters_.baseline) |
| 323 | + .with_reduction_factor(cast_threshold) |
| 324 | + .on(exec) |
| 325 | + ->generate(cast_args); |
| 326 | + } else { |
| 327 | + result = ResidualNorm<value_type>::build() |
| 328 | + .with_baseline(this->parameters_.baseline) |
| 329 | + .with_reduction_factor(cast_threshold) |
| 330 | + .on(exec) |
| 331 | + ->generate(cast_args); |
| 332 | + } |
| 333 | + }); |
| 334 | + return result; |
| 335 | + } |
| 336 | + |
| 337 | + residual_norm_factory_parameters parameters_; |
| 338 | +}; |
| 339 | + |
| 340 | + |
| 341 | +deferred_factory_parameter<CriterionFactory> absolute_residual_norm( |
| 342 | + double tolerance) |
| 343 | +{ |
| 344 | + return residual_norm_factory_parameters{} |
| 345 | + .with_threshold(tolerance) |
| 346 | + .with_baseline(mode::absolute); |
| 347 | +} |
| 348 | + |
| 349 | + |
| 350 | +deferred_factory_parameter<CriterionFactory> relative_residual_norm( |
| 351 | + double tolerance) |
| 352 | +{ |
| 353 | + return residual_norm_factory_parameters{} |
| 354 | + .with_threshold(tolerance) |
| 355 | + .with_baseline(mode::rhs_norm); |
| 356 | +} |
| 357 | + |
| 358 | + |
| 359 | +deferred_factory_parameter<CriterionFactory> initial_residual_norm( |
| 360 | + double tolerance) |
| 361 | +{ |
| 362 | + return residual_norm_factory_parameters{} |
| 363 | + .with_threshold(tolerance) |
| 364 | + .with_baseline(mode::initial_resnorm); |
| 365 | +} |
| 366 | + |
| 367 | + |
| 368 | +deferred_factory_parameter<CriterionFactory> absolute_implicit_residual_norm( |
| 369 | + double tolerance) |
| 370 | +{ |
| 371 | + return residual_norm_factory_parameters{} |
| 372 | + .with_threshold(tolerance) |
| 373 | + .with_baseline(mode::absolute) |
| 374 | + .with_implicit(true); |
| 375 | +} |
| 376 | + |
| 377 | + |
| 378 | +deferred_factory_parameter<CriterionFactory> relative_implicit_residual_norm( |
| 379 | + double tolerance) |
| 380 | +{ |
| 381 | + return residual_norm_factory_parameters{} |
| 382 | + .with_threshold(tolerance) |
| 383 | + .with_baseline(mode::rhs_norm) |
| 384 | + .with_implicit(true); |
| 385 | +} |
| 386 | + |
| 387 | + |
| 388 | +deferred_factory_parameter<CriterionFactory> initial_implicit_residual_norm( |
| 389 | + double tolerance) |
| 390 | +{ |
| 391 | + return residual_norm_factory_parameters{} |
| 392 | + .with_threshold(tolerance) |
| 393 | + .with_baseline(mode::initial_resnorm) |
| 394 | + .with_implicit(true); |
| 395 | +} |
| 396 | + |
237 | 397 |
|
238 | 398 | } // namespace stop |
239 | 399 | } // namespace gko |
0 commit comments