Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"nroot",
"nstage",
"nstep",
"numpy",
"odyad",
"Oleg",
"oneapi",
Expand Down
75 changes: 75 additions & 0 deletions russell_lab/src/matrix/mat_to_numpy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::Matrix;
use std::fmt::Write;

/// Generates numpy code with a nested list representing a matrix
///
/// When saved to a file, we can execute the numpy script as:
///
/// ```text
/// math < /tmp/russell_lab/the_matrix.m
/// ```
///
/// # Examples
///
/// ```
/// use russell_lab::{mat_to_numpy, Matrix};
///
/// fn main() {
/// #[rustfmt::skip]
/// let a = Matrix::from(&[
/// [ 1.0, 2.0],
/// [-3.0, 4.0],
/// [ 5.0, 6.0e-10],
/// ]);
///
/// let res = mat_to_numpy("a_matrix", &a);
///
/// let correct = "a_matrix = np.array([[1,2],[-3,4],[5,0.0000000006]],dtype=float)";
///
/// assert_eq!(res, correct);
/// }
/// ```
pub fn mat_to_numpy(name: &str, a: &Matrix) -> String {
let (nrow, ncol) = a.dims();
let mut buf = String::new();
write!(&mut buf, "{} = np.array([", name).unwrap();
for i in 0..nrow {
if i > 0 {
write!(&mut buf, "],").unwrap();
}
for j in 0..ncol {
if j == 0 {
write!(&mut buf, "[").unwrap();
} else {
write!(&mut buf, ",").unwrap();
}
let val = a.get(i, j);
write!(&mut buf, "{}", val).unwrap();
}
}
write!(&mut buf, "]],dtype=float)").unwrap();
buf
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
mod tests {
use super::{mat_to_numpy, Matrix};

#[test]
fn mat_to_numpy_works() {
#[rustfmt::skip]
let a = Matrix::from(&[
[ 1.0, 20.0],
[ 3.01, 4.0],
[-55.0, 6.0],
[ 7.0, 8.0e-10],
]);
let res = mat_to_numpy("a_matrix", &a);
assert_eq!(
res,
"a_matrix = np.array([[1,20],[3.01,4],[-55,6],[7,0.0000000008]],dtype=float)"
)
}
}
2 changes: 2 additions & 0 deletions russell_lab/src/matrix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ mod mat_svd;
mod mat_sym_rank_op;
mod mat_t_mat_mul;
mod mat_to_mathematica;
mod mat_to_numpy;
mod mat_to_static_array;
mod mat_update;
mod mat_write_vismatrix;
Expand Down Expand Up @@ -84,6 +85,7 @@ pub use mat_svd::*;
pub use mat_sym_rank_op::*;
pub use mat_t_mat_mul::*;
pub use mat_to_mathematica::*;
pub use mat_to_numpy::*;
pub use mat_to_static_array::*;
pub use mat_update::*;
pub use mat_write_vismatrix::*;
Expand Down