Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ mod tests {
#[test]
#[cfg(feature = "serde1")]
fn test_checkpointing_solver_initialization() {
use std::fmt::Debug;

use crate::core::checkpointing::{CheckpointingFrequency, FileCheckpoint};
use crate::core::test_utils::TestProblem;
use crate::core::{ArgminFloat, CostFunction};
Expand All @@ -588,7 +590,7 @@ mod tests {
impl<O, P, F> Solver<O, IterState<P, (), (), (), F>> for OptimizationAlgorithm
where
O: CostFunction<Param = P, Output = F>,
P: Clone,
P: Clone + Debug,
F: ArgminFloat,
{
const NAME: &'static str = "OptimizationAlgorithm";
Expand Down
2 changes: 1 addition & 1 deletion argmin/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ pub use problem::{CostFunction, Gradient, Hessian, Jacobian, LinearProgram, Oper
pub use result::OptimizationResult;
pub use serialization::{DeserializeOwnedAlias, SerializeAlias};
pub use solver::Solver;
pub use state::{IterState, LinearProgramState, PopulationState, State};
pub use state::{IterState, LinearProgramState, PopulationState, State, StateData};
pub use termination::{TerminationReason, TerminationStatus};
103 changes: 95 additions & 8 deletions argmin/src/core/observers/slog_logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//! See [`SlogLogger`] for details regarding usage.

use crate::core::observers::Observe;
use crate::core::state::StateData;
use crate::core::{Error, State, KV};
use slog;
use slog::{info, o, Drain, Key, Record, Serializer};
Expand All @@ -31,9 +32,38 @@ use std::sync::Mutex;
pub struct SlogLogger {
/// the logger
logger: slog::Logger,
/// Data to log. It is logged in order. Duplicates are not checked.
log_data: Vec<StateData>,
}

impl SlogLogger {
/// Specify the data to log. Data is logged in the order that it is specified
/// in the input `log_data` and duplicates are not removed.
///
/// The available data is any value obtained via the methods defined in the
/// [`State`] trait.
///
/// # Example
///
/// ```
/// use argmin::core::observers::SlogLogger;
/// use argmin::core::StateData;
///
/// // Default is to log function counts, best cost, cost, and iter. Modify
/// // it to also log the current parameters.
/// let mut log_data = Vec::new();
/// log_data.push(StateData::FunctionCounts);
/// log_data.push(StateData::BestCost);
/// log_data.push(StateData::Cost);
/// log_data.push(StateData::Iter);
/// log_data.push(StateData::Param);
/// let terminal_logger = SlogLogger::term().data(log_data);
/// ```
pub fn data(&mut self, log_data: Vec<StateData>) -> &mut Self {
self.log_data = log_data;
self
}

/// Log to the terminal.
///
/// Will block execution when buffer is full.
Expand Down Expand Up @@ -75,8 +105,15 @@ impl SlogLogger {
.overflow_strategy(overflow_strategy)
.build()
.fuse();
let log_data = vec![
StateData::FunctionCounts,
StateData::BestCost,
StateData::Cost,
StateData::Iter,
];
SlogLogger {
logger: slog::Logger::root(drain, o!()),
log_data,
}
}

Expand Down Expand Up @@ -138,8 +175,15 @@ impl SlogLogger {
.overflow_strategy(overflow_strategy)
.build()
.fuse();
let log_data = vec![
StateData::FunctionCounts,
StateData::BestCost,
StateData::Cost,
StateData::Iter,
];
Ok(SlogLogger {
logger: slog::Logger::root(drain, o!()),
log_data,
})
}
}
Expand All @@ -153,19 +197,62 @@ impl slog::KV for KV {
}
}

struct LogState<I>(I);
struct LogState<'a, I>(I, &'a [StateData]);

impl<I> slog::KV for LogState<&'_ I>
impl<'a, I> slog::KV for LogState<'a, &I>
where
I: State,
{
fn serialize(&self, _record: &Record, serializer: &mut dyn Serializer) -> slog::Result {
for (k, &v) in self.0.get_func_counts().iter() {
serializer.emit_u64(Key::from(k.clone()), v)?;
let state = self.0;
for data in self.1 {
let key = Key::from(data.to_string());
match data {
StateData::BestCost => {
serializer.emit_str(key, &state.get_best_cost().to_string())?;
}
StateData::BestParam => {
let param = state
.get_best_param()
.map_or("None".to_string(), |p| format!("{:?}", p));
serializer.emit_str(key, &param)?;
}
StateData::Cost => {
serializer.emit_str(key, &self.0.get_cost().to_string())?;
}
StateData::FunctionCounts => {
for (k, &v) in state.get_func_counts().iter() {
serializer.emit_u64(Key::from(k.clone()), v)?;
}
}
StateData::IsBest => serializer.emit_bool(key, state.is_best())?,
StateData::Iter => serializer.emit_u64(key, state.get_iter())?,
StateData::LastBestIter => serializer.emit_u64(key, state.get_last_best_iter())?,
StateData::MaxIters => serializer.emit_u64(key, state.get_max_iters())?,
StateData::Param => {
let param = state
.get_param()
.map_or("None".to_string(), |p| format!("{:?}", p));
serializer.emit_str(key, &param)?;
}
StateData::TargetCost => {
serializer.emit_str(key, &state.get_target_cost().to_string())?
}
StateData::TerminationReason => serializer.emit_str(
key,
state.get_termination_reason().map_or("None", |r| r.text()),
)?,
StateData::TerminationStatus => {
serializer.emit_str(key, &state.get_termination_status().to_string())?
}
StateData::Time => serializer.emit_str(
key,
&state
.get_time()
.map_or("None".to_string(), |t| format!("{:?}", t)),
)?,
}
}
serializer.emit_str(Key::from("best_cost"), &self.0.get_best_cost().to_string())?;
serializer.emit_str(Key::from("cost"), &self.0.get_cost().to_string())?;
serializer.emit_u64(Key::from("iter"), self.0.get_iter())?;
Ok(())
}
}
Expand All @@ -182,7 +269,7 @@ where

/// Logs information about the progress of the optimization after every iteration.
fn observe_iter(&mut self, state: &I, kv: &KV) -> Result<(), Error> {
info!(self.logger, ""; LogState(state), kv);
info!(self.logger, ""; LogState(state, &self.log_data), kv);
Ok(())
}
}
Expand Down
3 changes: 2 additions & 1 deletion argmin/src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// # Example
///
/// ```
/// use std::fmt::Debug;
/// use argmin::core::{
/// ArgminFloat, Solver, IterState, CostFunction, Error, KV, Problem, TerminationReason, TerminationStatus
/// };
Expand All @@ -33,7 +34,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// impl<O, P, G, J, H, F> Solver<O, IterState<P, G, J, H, F>> for OptimizationAlgorithm
/// where
/// O: CostFunction<Param = P, Output = F>,
/// P: Clone,
/// P: Clone + Debug,
/// F: ArgminFloat
/// {
/// const NAME: &'static str = "OptimizationAlgorithm";
Expand Down
4 changes: 2 additions & 2 deletions argmin/src/core/state/iterstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core::{ArgminFloat, Problem, State, TerminationReason, TerminationSta
use instant;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Debug};

/// Maintains the state from iteration to iteration of a solver
///
Expand Down Expand Up @@ -864,7 +864,7 @@ where

impl<P, G, J, H, F> State for IterState<P, G, J, H, F>
where
P: Clone,
P: Clone + Debug,
F: ArgminFloat,
{
/// Type of parameter vector
Expand Down
4 changes: 2 additions & 2 deletions argmin/src/core/state/linearprogramstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core::{ArgminFloat, Problem, State, TerminationReason, TerminationSta
use instant;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Debug};

/// Maintains the state from iteration to iteration of a solver
///
Expand Down Expand Up @@ -155,7 +155,7 @@ impl<P, F> LinearProgramState<P, F> {

impl<P, F> State for LinearProgramState<P, F>
where
P: Clone,
P: Clone + Debug,
F: ArgminFloat,
{
/// Type of parameter vector
Expand Down
56 changes: 54 additions & 2 deletions argmin/src/core/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,59 @@ pub use linearprogramstate::LinearProgramState;
pub use populationstate::PopulationState;

use crate::core::{ArgminFloat, Problem, TerminationReason, TerminationStatus};
use std::collections::HashMap;
use std::{collections::HashMap, fmt, fmt::Debug};

/// A set of values that can be queried using the [`State`] trait. Useful to configure
/// an observer.
#[derive(Copy, Clone, Debug)]
pub enum StateData {
/// The Param data
Param,
/// The best param data
BestParam,
/// The maximum number of iterations
MaxIters,
/// The iteration number
Iter,
/// The cost of the current iteration
Cost,
/// The best cost so far
BestCost,
/// The target cost
TargetCost,
/// How many times each function within the solver has been called
FunctionCounts,
/// The current time
Time,
/// Which iteration the last best cost was found
LastBestIter,
/// Boolean of if this iteration is the best
IsBest,
/// Basic yes/no status of if the solver has terminated
TerminationStatus,
/// If the solver has terminated, what was the reason
TerminationReason,
}

impl fmt::Display for StateData {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StateData::BestCost => write!(f, "best_cost"),
StateData::BestParam => write!(f, "best_param"),
StateData::Cost => write!(f, "cost"),
StateData::FunctionCounts => write!(f, "function_counts"),
StateData::IsBest => write!(f, "is_best"),
StateData::Iter => write!(f, "iter"),
StateData::LastBestIter => write!(f, "last_best_iter"),
StateData::MaxIters => write!(f, "max_iters"),
StateData::Param => write!(f, "param"),
StateData::TargetCost => write!(f, "target_cost"),
StateData::TerminationReason => write!(f, "termination_reason"),
StateData::TerminationStatus => write!(f, "termination_status"),
StateData::Time => write!(f, "time"),
}
}
}

/// Minimal interface which struct used for managing state in solvers have to implement.
///
Expand Down Expand Up @@ -44,7 +96,7 @@ use std::collections::HashMap;
/// for this (so far f32 and f64).
pub trait State {
/// Type of parameter vector
type Param;
type Param: Debug;
/// Floating point precision (f32 or f64)
type Float: ArgminFloat;

Expand Down
4 changes: 2 additions & 2 deletions argmin/src/core/state/populationstate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::core::{ArgminFloat, Problem, State, TerminationReason, TerminationSta
use instant;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Debug};

/// Maintains the state from iteration to iteration of a population-based solver
///
Expand Down Expand Up @@ -434,7 +434,7 @@ where

impl<P, F> State for PopulationState<P, F>
where
P: Clone,
P: Clone + Debug,
F: ArgminFloat,
{
/// Type of an individual
Expand Down
3 changes: 3 additions & 0 deletions argmin/src/solver/conjugategradient/cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::fmt::Debug;

use crate::core::{
ArgminFloat, Error, IterState, Operator, Problem, SerializeAlias, Solver, State, KV,
};
Expand Down Expand Up @@ -93,6 +95,7 @@ impl<P, O, F> Solver<O, IterState<P, (), (), (), F>> for ConjugateGradient<P, F>
where
O: Operator<Param = P, Output = P>,
P: Clone
+ Debug
+ SerializeAlias
+ ArgminDot<P, F>
+ ArgminSub<P, P>
Expand Down
4 changes: 3 additions & 1 deletion argmin/src/solver/conjugategradient/nonlinear_cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::fmt::Debug;

use crate::core::{
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
LineSearch, NLCGBetaUpdate, OptimizationResult, Problem, SerializeAlias, Solver, State, KV,
Expand Down Expand Up @@ -122,7 +124,7 @@ impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), F>>
for NonlinearConjugateGradient<P, L, B, F>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminAdd<P, P> + ArgminMul<F, P>,
P: Clone + Debug + SerializeAlias + DeserializeOwnedAlias + ArgminAdd<P, P> + ArgminMul<F, P>,
G: Clone
+ SerializeAlias
+ DeserializeOwnedAlias
Expand Down
4 changes: 3 additions & 1 deletion argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::fmt::Debug;

use crate::core::{
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
Jacobian, LineSearch, Operator, OptimizationResult, Problem, SerializeAlias, Solver,
Expand Down Expand Up @@ -84,7 +86,7 @@ impl<L, F: ArgminFloat> GaussNewtonLS<L, F> {
impl<O, L, F, P, G, J, U> Solver<O, IterState<P, G, J, (), F>> for GaussNewtonLS<L, F>
where
O: Operator<Param = P, Output = U> + Jacobian<Param = P, Jacobian = J>,
P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
P: Clone + Debug + SerializeAlias + DeserializeOwnedAlias + ArgminMul<F, P>,
G: Clone + SerializeAlias + DeserializeOwnedAlias,
U: ArgminL2Norm<F>,
J: Clone
Expand Down
4 changes: 3 additions & 1 deletion argmin/src/solver/gaussnewton/gaussnewton_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::fmt::Debug;

use crate::core::{
ArgminFloat, Error, IterState, Jacobian, Operator, Problem, Solver, State, TerminationReason,
TerminationStatus, KV,
Expand Down Expand Up @@ -112,7 +114,7 @@ impl<F: ArgminFloat> Default for GaussNewton<F> {
impl<O, F, P, J, U> Solver<O, IterState<P, (), J, (), F>> for GaussNewton<F>
where
O: Operator<Param = P, Output = U> + Jacobian<Param = P, Jacobian = J>,
P: Clone + ArgminSub<P, P> + ArgminMul<F, P>,
P: Clone + Debug + ArgminSub<P, P> + ArgminMul<F, P>,
U: ArgminL2Norm<F>,
J: Clone
+ ArgminTranspose<J>
Expand Down
Loading