EnzymeAD/Enzyme #1921
I am trying to use the Rust library `faer` with Enzyme and am running into problems. For some reason I cannot pass references to `faer::Col<T>` objects into a differentiated function, and I cannot figure out how to get it to work.
Before running these examples, I exported the Enzyme debug environment variables:

```shell
export ENZYME_PRINT_TA=1
export ENZYME_PRINT_AA=1
export ENZYME_PRINT=1
export ENZYME_PRINT_MOD=1
export ENZYME_PRINT_MOD_AFTER=1
```

and then ran:

```shell
cargo +enzyme build
```
Example: $f(\vec{x}) = \vec{c} \cdot \vec{x}$
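For reference, the expected values noted in the code comments below follow directly from the definition, taking $\vec{c} = (1, 2, 3)^\top$ from the code and evaluating at $\vec{x} = (1, 1, 1)^\top$:

$$
f(\vec{x}) = \sum_{i=1}^{3} c_i x_i, \qquad
\frac{\partial f}{\partial x_i} = c_i
\;\Longrightarrow\;
\nabla f(\vec{x}) = \vec{c} = (1, 2, 3)^\top, \qquad
f\big((1, 1, 1)^\top\big) = 1 + 2 + 3 = 6.
$$

Since $f$ is linear, the gradient does not depend on $\vec{x}$; any input should yield a gradient of $(1, 2, 3)$.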
Initial Idea
```rust
#![feature(autodiff)]
use faer::{col, Col};

#[autodiff(df, Reverse, Duplicated, Duplicated)]
fn f(x: &Col<f64>, y: &mut f64) {
    let c: Col<f64> = col![1.0, 2.0, 3.0];
    *y = c.transpose() * x;
}

fn main() {
    let x: Col<f64> = col![1.0, 1.0, 1.0];
    let mut y: f64 = 0.0;
    let mut grad: Col<f64> = col![0.0, 0.0, 0.0];
    let mut seed: f64 = 1.0;

    // Each Duplicated argument is followed by its shadow:
    // `grad` receives df/dx, `seed` is the adjoint (seed) for y.
    df(&x, &mut grad, &mut y, &mut seed);

    // Should be 1.0, 1.0, 1.0
    println!(" f evaluated at:");
    x.iter().enumerate().for_each(|(i, e)| {
        println!("x[{i}] = {e}");
    });

    // Should be 1.0, 2.0, 3.0
    println!("\n gradient of f at x:");
    grad.iter().enumerate().for_each(|(i, e)| {
        println!("grad[{i}] = {e}");
    });

    // Should be 6.0
    println!("\n f at x:");
    println!("f(x) = {y}");
}
```
Output: ex1.txt
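As a baseline that involves neither faer nor Enzyme, here is a minimal sketch in plain std Rust that computes $f$ and its hand-derived gradient. It assumes fixed-size arrays are an acceptable stand-in for `Col<f64>`; the names `C`, `f`, and `grad_f` are illustrative and not part of the original code. Its output is what `df` should reproduce.

```rust
/// Coefficients c = (1, 2, 3), matching `col![1.0, 2.0, 3.0]` in the issue.
const C: [f64; 3] = [1.0, 2.0, 3.0];

/// f(x) = c · x, written without faer.
fn f(x: &[f64; 3]) -> f64 {
    C.iter().zip(x.iter()).map(|(c, xi)| c * xi).sum()
}

/// Hand-derived gradient: d/dx_i (sum_j c_j x_j) = c_i, independent of x.
fn grad_f(_x: &[f64; 3]) -> [f64; 3] {
    C
}

fn main() {
    let x = [1.0, 1.0, 1.0];
    println!("f(x) = {}", f(&x));
    println!("grad = {:?}", grad_f(&x));
}
```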
Tried This Next
```rust
#![feature(autodiff)]
use faer::{col, Col, ColRef};

#[autodiff(df, Reverse, Duplicated, Duplicated)]
fn f(x: &[f64], y: &mut f64) {
    let c: Col<f64> = col![1.0, 2.0, 3.0];
    // Rebind x as a faer column view over the slice (variable shadowing)
    let x: ColRef<f64> = col::from_slice(x);
    *y = c.transpose() * x;
}

fn main() {
    let x: Col<f64> = col![1.0, 1.0, 1.0];
    let mut y: f64 = 0.0;
    let mut grad: Col<f64> = col![0.0, 0.0, 0.0];
    let mut seed: f64 = 1.0;

    // Pass plain slices instead of &Col; `grad.as_slice_mut()` is the
    // shadow of x and `seed` is the adjoint (seed) for y.
    df(x.as_slice(), grad.as_slice_mut(), &mut y, &mut seed);

    // Should be 1.0, 1.0, 1.0
    println!(" f evaluated at:");
    x.iter().enumerate().for_each(|(i, e)| {
        println!("x[{i}] = {e}");
    });

    // Should be 1.0, 2.0, 3.0
    println!("\n gradient of f at x:");
    grad.iter().enumerate().for_each(|(i, e)| {
        println!("grad[{i}] = {e}");
    });

    // Should be 6.0
    println!("\n f at x:");
    println!("f(x) = {y}");
}
```
Output: ex2.txt
I could not get this version to work either.
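One way to double-check whatever `df` returns, once it does run, is a central finite-difference estimate on the same slice-based signature. This is a minimal sketch in plain Rust with no faer or Enzyme involved; `fd_gradient` and the step size `h` are hypothetical choices for illustration, not part of either library.

```rust
/// f(x) = c · x with c = (1, 2, 3), on a plain slice as in the second attempt.
fn f(x: &[f64]) -> f64 {
    let c: [f64; 3] = [1.0, 2.0, 3.0];
    c.iter().zip(x).map(|(ci, xi)| ci * xi).sum()
}

/// Central finite-difference estimate of the gradient of `f` at `x`:
/// grad_i ≈ (f(x + h e_i) - f(x - h e_i)) / (2h).
fn fd_gradient(f: impl Fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<f64> {
    let mut xp = x.to_vec();
    (0..x.len())
        .map(|i| {
            let orig = xp[i];
            xp[i] = orig + h;
            let f_plus = f(&xp[..]);
            xp[i] = orig - h;
            let f_minus = f(&xp[..]);
            xp[i] = orig; // restore before perturbing the next coordinate
            (f_plus - f_minus) / (2.0 * h)
        })
        .collect()
}

fn main() {
    let x = [1.0, 1.0, 1.0];
    println!("f(x) = {}", f(&x));
    for (i, g) in fd_gradient(f, &x, 1e-6).iter().enumerate() {
        println!("grad[{i}] ≈ {g:.6}");
    }
}
```

Because $f$ is linear, the central difference is exact up to floating-point rounding, so each entry should come out very close to $c_i$.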
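Here is my `Cargo.toml` file:

```toml
[package]
name = "learning-enzyme-faer"
version = "0.1.0"
edition = "2021"

[profile.release]
lto = "fat"

[profile.dev]
lto = "fat"

# Note: Cargo does not read a [toolchain] table (it warns about an unused
# manifest key); rustup reads it from rust-toolchain.toml. The `+enzyme`
# flag on the command line selects the toolchain here anyway.
[toolchain]
channel = "enzyme"

[dependencies]
faer = "0.19"
```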