Skip to content

Commit cf19f90

Browse files
authored
Rollup merge of rust-lang#142640 - Sa4dUs:ad-intrinsic, r=ZuseZ4
Implement autodiff using intrinsics This PR aims to move autodiff logic to `autodiff` intrinsic. Allowing us to delete a great part of our frontend code and overall, simplify the compilation pipeline of autodiff functions.
2 parents f753cd4 + 86c250f commit cf19f90

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

core/src/intrinsics/mod.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,6 +3157,44 @@ pub const unsafe fn copysignf64(x: f64, y: f64) -> f64;
31573157
#[rustc_intrinsic]
31583158
pub const unsafe fn copysignf128(x: f128, y: f128) -> f128;
31593159

3160+
/// Generates the LLVM body for the automatic differentiation of `f` using Enzyme,
3161+
/// with `df` as the derivative function and `args` as its arguments.
3162+
///
3163+
/// Used internally as the body of `df` when expanding the `#[autodiff_forward]`
3164+
/// and `#[autodiff_reverse]` attribute macros.
3165+
///
3166+
/// Type Parameters:
3167+
/// - `F`: The original function to differentiate. Must be a function item.
3168+
/// - `G`: The derivative function. Must be a function item.
3169+
/// - `T`: A tuple of arguments passed to `df`.
3170+
/// - `R`: The return type of the derivative function.
3171+
///
3172+
/// This shows where the `autodiff` intrinsic is used during macro expansion:
3173+
///
3174+
/// ```rust,ignore (macro example)
3175+
/// #[autodiff_forward(df1, Dual, Const, Dual)]
3176+
/// pub fn f1(x: &[f64], y: f64) -> f64 {
3177+
/// unimplemented!()
3178+
/// }
3179+
/// ```
3180+
///
3181+
/// expands to:
3182+
///
3183+
/// ```rust,ignore (macro example)
3184+
/// #[rustc_autodiff]
3185+
/// #[inline(never)]
3186+
/// pub fn f1(x: &[f64], y: f64) -> f64 {
3187+
/// ::core::panicking::panic("not implemented")
3188+
/// }
3189+
/// #[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
3190+
/// pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
3191+
/// ::core::intrinsics::autodiff(f1::<>, df1::<>, (x, bx_0, y))
3192+
/// }
3193+
/// ```
3194+
#[rustc_nounwind]
3195+
#[rustc_intrinsic]
3196+
pub const fn autodiff<F, G, T: crate::marker::Tuple, R>(f: F, df: G, args: T) -> R;
3197+
31603198
/// Inform Miri that a given pointer definitely has a certain alignment.
31613199
#[cfg(miri)]
31623200
#[rustc_allow_const_fn_unstable(const_eval_select)]

core/src/macros/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,7 @@ pub(crate) mod builtin {
14951495
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
14961496
#[unstable(feature = "autodiff", issue = "124509")]
14971497
#[allow_internal_unstable(rustc_attrs)]
1498+
#[allow_internal_unstable(core_intrinsics)]
14981499
#[rustc_builtin_macro]
14991500
pub macro autodiff_forward($item:item) {
15001501
/* compiler built-in */
@@ -1513,6 +1514,7 @@ pub(crate) mod builtin {
15131514
/// (or explicitly returns `-> ()`). Otherwise, it must be set to one of the allowed activities.
15141515
#[unstable(feature = "autodiff", issue = "124509")]
15151516
#[allow_internal_unstable(rustc_attrs)]
1517+
#[allow_internal_unstable(core_intrinsics)]
15161518
#[rustc_builtin_macro]
15171519
pub macro autodiff_reverse($item:item) {
15181520
/* compiler built-in */

0 commit comments

Comments
 (0)