Skip to content

Commit 8c18a66

Browse files
committed
Add the Enzyme frontend code
1 parent 730d5d4 commit 8c18a66

File tree

16 files changed

+1169
-0
lines changed

16 files changed

+1169
-0
lines changed
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
use crate::expand::typetree::TypeTree;
2+
use std::str::FromStr;
3+
use std::fmt::{self, Display, Formatter};
4+
use crate::ptr::P;
5+
use crate::{Ty, TyKind};
6+
7+
use crate::expand::HashStable_Generic;
8+
use crate::expand::Encodable;
9+
use crate::expand::Decodable;
10+
11+
#[allow(dead_code)]
12+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
13+
pub enum DiffMode {
14+
Inactive,
15+
Source,
16+
Forward,
17+
Reverse,
18+
ForwardFirst,
19+
ReverseFirst,
20+
}
21+
22+
pub fn is_rev(mode: DiffMode) -> bool {
23+
match mode {
24+
DiffMode::Reverse | DiffMode::ReverseFirst => true,
25+
_ => false,
26+
}
27+
}
28+
pub fn is_fwd(mode: DiffMode) -> bool {
29+
match mode {
30+
DiffMode::Forward | DiffMode::ForwardFirst => true,
31+
_ => false,
32+
}
33+
}
34+
35+
impl Display for DiffMode {
36+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37+
match self {
38+
DiffMode::Inactive => write!(f, "Inactive"),
39+
DiffMode::Source => write!(f, "Source"),
40+
DiffMode::Forward => write!(f, "Forward"),
41+
DiffMode::Reverse => write!(f, "Reverse"),
42+
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
43+
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
44+
}
45+
}
46+
}
47+
48+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
49+
if activity == DiffActivity::None {
50+
// Only valid if primal returns (), but we can't check that here.
51+
return true;
52+
}
53+
match mode {
54+
DiffMode::Inactive => false,
55+
DiffMode::Source => false,
56+
DiffMode::Forward | DiffMode::ForwardFirst => {
57+
activity == DiffActivity::Dual ||
58+
activity == DiffActivity::DualOnly ||
59+
activity == DiffActivity::Const
60+
}
61+
DiffMode::Reverse | DiffMode::ReverseFirst => {
62+
activity == DiffActivity::Const ||
63+
activity == DiffActivity::Active ||
64+
activity == DiffActivity::ActiveOnly
65+
}
66+
}
67+
}
68+
fn is_ptr_or_ref(ty: &Ty) -> bool {
69+
match ty.kind {
70+
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
71+
_ => false,
72+
}
73+
}
74+
// TODO We should make this more robust to also
75+
// accept aliases of f32 and f64
76+
//fn is_float(ty: &Ty) -> bool {
77+
// false
78+
//}
79+
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
80+
if is_ptr_or_ref(ty) {
81+
return activity == DiffActivity::Dual ||
82+
activity == DiffActivity::DualOnly ||
83+
activity == DiffActivity::Duplicated ||
84+
activity == DiffActivity::DuplicatedOnly ||
85+
activity == DiffActivity::Const;
86+
}
87+
true
88+
//if is_scalar_ty(&ty) {
89+
// return activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly ||
90+
// activity == DiffActivity::Const;
91+
//}
92+
}
93+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
94+
return match mode {
95+
DiffMode::Inactive => false,
96+
DiffMode::Source => false,
97+
DiffMode::Forward | DiffMode::ForwardFirst => {
98+
// These are the only valid cases
99+
activity == DiffActivity::Dual ||
100+
activity == DiffActivity::DualOnly ||
101+
activity == DiffActivity::Const
102+
}
103+
DiffMode::Reverse | DiffMode::ReverseFirst => {
104+
// These are the only valid cases
105+
activity == DiffActivity::Active ||
106+
activity == DiffActivity::ActiveOnly ||
107+
activity == DiffActivity::Const ||
108+
activity == DiffActivity::Duplicated ||
109+
activity == DiffActivity::DuplicatedOnly
110+
}
111+
};
112+
}
113+
pub fn invalid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> Option<usize> {
114+
for i in 0..activity_vec.len() {
115+
if !valid_input_activity(mode, activity_vec[i]) {
116+
return Some(i);
117+
}
118+
}
119+
None
120+
}
121+
122+
#[allow(dead_code)]
123+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
124+
pub enum DiffActivity {
125+
None,
126+
Const,
127+
Active,
128+
ActiveOnly,
129+
Dual,
130+
DualOnly,
131+
Duplicated,
132+
DuplicatedOnly,
133+
FakeActivitySize
134+
}
135+
136+
impl Display for DiffActivity {
137+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
138+
match self {
139+
DiffActivity::None => write!(f, "None"),
140+
DiffActivity::Const => write!(f, "Const"),
141+
DiffActivity::Active => write!(f, "Active"),
142+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
143+
DiffActivity::Dual => write!(f, "Dual"),
144+
DiffActivity::DualOnly => write!(f, "DualOnly"),
145+
DiffActivity::Duplicated => write!(f, "Duplicated"),
146+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
147+
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
148+
}
149+
}
150+
}
151+
152+
impl FromStr for DiffMode {
153+
type Err = ();
154+
155+
fn from_str(s: &str) -> Result<DiffMode, ()> {
156+
match s {
157+
"Inactive" => Ok(DiffMode::Inactive),
158+
"Source" => Ok(DiffMode::Source),
159+
"Forward" => Ok(DiffMode::Forward),
160+
"Reverse" => Ok(DiffMode::Reverse),
161+
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
162+
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
163+
_ => Err(()),
164+
}
165+
}
166+
}
167+
impl FromStr for DiffActivity {
168+
type Err = ();
169+
170+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
171+
match s {
172+
"None" => Ok(DiffActivity::None),
173+
"Active" => Ok(DiffActivity::Active),
174+
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
175+
"Const" => Ok(DiffActivity::Const),
176+
"Dual" => Ok(DiffActivity::Dual),
177+
"DualOnly" => Ok(DiffActivity::DualOnly),
178+
"Duplicated" => Ok(DiffActivity::Duplicated),
179+
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
180+
_ => Err(()),
181+
}
182+
}
183+
}
184+
185+
#[allow(dead_code)]
186+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
187+
pub struct AutoDiffAttrs {
188+
pub mode: DiffMode,
189+
pub ret_activity: DiffActivity,
190+
pub input_activity: Vec<DiffActivity>,
191+
}
192+
193+
impl AutoDiffAttrs {
194+
pub fn has_ret_activity(&self) -> bool {
195+
match self.ret_activity {
196+
DiffActivity::None => false,
197+
_ => true,
198+
}
199+
}
200+
pub fn has_active_only_ret(&self) -> bool {
201+
match self.ret_activity {
202+
DiffActivity::ActiveOnly => true,
203+
_ => false,
204+
}
205+
}
206+
}
207+
208+
impl AutoDiffAttrs {
209+
pub fn inactive() -> Self {
210+
AutoDiffAttrs {
211+
mode: DiffMode::Inactive,
212+
ret_activity: DiffActivity::None,
213+
input_activity: Vec::new(),
214+
}
215+
}
216+
pub fn source() -> Self {
217+
AutoDiffAttrs {
218+
mode: DiffMode::Source,
219+
ret_activity: DiffActivity::None,
220+
input_activity: Vec::new(),
221+
}
222+
}
223+
224+
pub fn is_active(&self) -> bool {
225+
match self.mode {
226+
DiffMode::Inactive => false,
227+
_ => {
228+
true
229+
}
230+
}
231+
}
232+
233+
pub fn is_source(&self) -> bool {
234+
match self.mode {
235+
DiffMode::Source => true,
236+
_ => false,
237+
}
238+
}
239+
pub fn apply_autodiff(&self) -> bool {
240+
match self.mode {
241+
DiffMode::Inactive => false,
242+
DiffMode::Source => false,
243+
_ => {
244+
true
245+
}
246+
}
247+
}
248+
249+
pub fn into_item(
250+
self,
251+
source: String,
252+
target: String,
253+
inputs: Vec<TypeTree>,
254+
output: TypeTree,
255+
) -> AutoDiffItem {
256+
AutoDiffItem { source, target, inputs, output, attrs: self }
257+
}
258+
}
259+
260+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
261+
pub struct AutoDiffItem {
262+
pub source: String,
263+
pub target: String,
264+
pub attrs: AutoDiffAttrs,
265+
pub inputs: Vec<TypeTree>,
266+
pub output: TypeTree,
267+
}
268+
269+
impl fmt::Display for AutoDiffItem {
270+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271+
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
272+
write!(f, " with attributes: {:?}", self.attrs)?;
273+
write!(f, " with inputs: {:?}", self.inputs)?;
274+
write!(f, " with output: {:?}", self.output)
275+
}
276+
}

compiler/rustc_ast/src/expand/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
77
use crate::MetaItem;
88

99
pub mod allocator;
10+
pub mod autodiff_attrs;
11+
pub mod typetree;
1012

1113
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1214
pub struct StrippedCfgItem<ModId = DefId> {
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
use crate::expand::HashStable_Generic;
2+
use crate::expand::Encodable;
3+
use crate::expand::Decodable;
4+
5+
use std::fmt;
6+
7+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8+
pub enum Kind {
9+
Anything,
10+
Integer,
11+
Pointer,
12+
Half,
13+
Float,
14+
Double,
15+
Unknown,
16+
}
17+
18+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
19+
pub struct TypeTree(pub Vec<Type>);
20+
21+
impl TypeTree {
22+
pub fn new() -> Self {
23+
Self(Vec::new())
24+
}
25+
pub fn all_ints() -> Self {
26+
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
27+
}
28+
pub fn int(size: usize) -> Self {
29+
let mut ints = Vec::with_capacity(size);
30+
for i in 0..size {
31+
ints.push(Type { offset: i as isize, size: 1, kind: Kind::Integer, child: TypeTree::new() });
32+
}
33+
Self(ints)
34+
}
35+
}
36+
37+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38+
pub struct FncTree {
39+
pub args: Vec<TypeTree>,
40+
pub ret: TypeTree,
41+
}
42+
43+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
44+
pub struct Type {
45+
pub offset: isize,
46+
pub size: usize,
47+
pub kind: Kind,
48+
pub child: TypeTree,
49+
}
50+
51+
impl Type {
52+
pub fn add_offset(self, add: isize) -> Self {
53+
let offset = match self.offset {
54+
-1 => add,
55+
x => add + x,
56+
};
57+
58+
Self { size: self.size, kind: self.kind, child: self.child, offset }
59+
}
60+
}
61+
62+
impl fmt::Display for Type {
63+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64+
<Self as fmt::Debug>::fmt(self, f)
65+
}
66+
}

compiler/rustc_builtin_macros/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
33
version = "0.0.0"
44
edition = "2021"
55

6+
7+
[lints.rust]
8+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }
9+
610
[lib]
711
doctest = false
812

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function
22
builtin_macros_alloc_must_statics = allocators must be statics
33
4+
builtin_macros_autodiff_unknown_activity = did not recognize activity {$act}
5+
builtin_macros_autodiff = autodiff must be applied to function
6+
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7+
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
8+
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
9+
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
10+
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
11+
412
builtin_macros_asm_clobber_abi = clobber_abi
513
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs
614
builtin_macros_asm_clobber_outputs = generic outputs

0 commit comments

Comments
 (0)