Skip to content

Commit 50a8258

Browse files
committed
move AD types from rustc_middle to rustc_ast
1 parent dee82b3 commit 50a8258

File tree

24 files changed

+429
-13
lines changed

24 files changed

+429
-13
lines changed

compiler/rustc_ast/src/ast.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,11 +1592,13 @@ impl MacCall {
15921592
}
15931593
}
15941594

1595+
/// Manuel
15951596
/// Arguments passed to an attribute macro.
15961597
#[derive(Clone, Encodable, Decodable, Debug)]
15971598
pub enum AttrArgs {
15981599
/// No arguments: `#[attr]`.
15991600
Empty,
1601+
/// Manuel autodiff
16001602
/// Delimited arguments: `#[attr()/[]/{}]`.
16011603
Delimited(DelimArgs),
16021604
/// Arguments of a key-value attribute: `#[attr = "value"]`.
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
use super::typetree::TypeTree;
2+
use std::str::FromStr;
3+
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd};
4+
use crate::HashStableContext;
5+
6+
#[allow(dead_code)]
7+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
8+
pub enum DiffMode {
9+
Inactive,
10+
Source,
11+
Forward,
12+
Reverse,
13+
}
14+
15+
#[allow(dead_code)]
16+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
17+
pub enum DiffActivity {
18+
None,
19+
Active,
20+
Const,
21+
Duplicated,
22+
DuplicatedNoNeed,
23+
}
24+
fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize {
25+
match value {
26+
DiffActivity::None => 0,
27+
DiffActivity::Active => 1,
28+
DiffActivity::Const => 2,
29+
DiffActivity::Duplicated => 3,
30+
DiffActivity::DuplicatedNoNeed => 4,
31+
}
32+
}
33+
fn clause_diffmode_discriminant(value: &DiffMode) -> usize {
34+
match value {
35+
DiffMode::Inactive => 0,
36+
DiffMode::Source => 1,
37+
DiffMode::Forward => 2,
38+
DiffMode::Reverse => 3,
39+
}
40+
}
41+
42+
43+
impl<CTX: HashStableContext> HashStable<CTX> for DiffMode {
44+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
45+
clause_diffmode_discriminant(self).hash_stable(hcx, hasher);
46+
}
47+
}
48+
49+
impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity {
50+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
51+
clause_diffactivity_discriminant(self).hash_stable(hcx, hasher);
52+
}
53+
}
54+
55+
56+
impl FromStr for DiffActivity {
57+
type Err = ();
58+
59+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
60+
match s {
61+
"None" => Ok(DiffActivity::None),
62+
"Active" => Ok(DiffActivity::Active),
63+
"Const" => Ok(DiffActivity::Const),
64+
"Duplicated" => Ok(DiffActivity::Duplicated),
65+
"DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed),
66+
_ => Err(()),
67+
}
68+
}
69+
}
70+
71+
#[allow(dead_code)]
72+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
73+
pub struct AutoDiffAttrs {
74+
pub mode: DiffMode,
75+
pub ret_activity: DiffActivity,
76+
pub input_activity: Vec<DiffActivity>,
77+
}
78+
79+
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
80+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
81+
self.mode.hash_stable(hcx, hasher);
82+
self.ret_activity.hash_stable(hcx, hasher);
83+
self.input_activity.hash_stable(hcx, hasher);
84+
}
85+
}
86+
87+
impl AutoDiffAttrs {
88+
pub fn inactive() -> Self {
89+
AutoDiffAttrs {
90+
mode: DiffMode::Inactive,
91+
ret_activity: DiffActivity::None,
92+
input_activity: Vec::new(),
93+
}
94+
}
95+
96+
pub fn is_active(&self) -> bool {
97+
match self.mode {
98+
DiffMode::Inactive => false,
99+
_ => true,
100+
}
101+
}
102+
103+
pub fn is_source(&self) -> bool {
104+
match self.mode {
105+
DiffMode::Source => true,
106+
_ => false,
107+
}
108+
}
109+
pub fn apply_autodiff(&self) -> bool {
110+
match self.mode {
111+
DiffMode::Inactive => false,
112+
DiffMode::Source => false,
113+
_ => true,
114+
}
115+
}
116+
117+
pub fn into_item(
118+
self,
119+
source: String,
120+
target: String,
121+
inputs: Vec<TypeTree>,
122+
output: TypeTree,
123+
) -> AutoDiffItem {
124+
AutoDiffItem { source, target, inputs, output, attrs: self }
125+
}
126+
}
127+
128+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
129+
pub struct AutoDiffItem {
130+
pub source: String,
131+
pub target: String,
132+
pub attrs: AutoDiffAttrs,
133+
pub inputs: Vec<TypeTree>,
134+
pub output: TypeTree,
135+
}
136+
137+
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem {
138+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
139+
self.source.hash_stable(hcx, hasher);
140+
self.target.hash_stable(hcx, hasher);
141+
self.attrs.hash_stable(hcx, hasher);
142+
for tt in &self.inputs {
143+
tt.0.hash_stable(hcx, hasher);
144+
}
145+
//self.inputs.hash_stable(hcx, hasher);
146+
self.output.0.hash_stable(hcx, hasher);
147+
}
148+
}
149+

compiler/rustc_ast/src/expand/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use rustc_span::{def_id::DefId, symbol::Ident};
55
use crate::MetaItem;
66

77
pub mod allocator;
8+
pub mod typetree;
9+
pub mod autodiff_attrs;
810

911
#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
1012
pub struct StrippedCfgItem<ModId = DefId> {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
use std::fmt;
2+
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd};
3+
use crate::HashStableContext;
4+
5+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
6+
pub enum Kind {
7+
Anything,
8+
Integer,
9+
Pointer,
10+
Half,
11+
Float,
12+
Double,
13+
Unknown,
14+
}
15+
impl<CTX: HashStableContext> HashStable<CTX> for Kind {
16+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
17+
clause_kind_discriminant(self).hash_stable(hcx, hasher);
18+
}
19+
}
20+
fn clause_kind_discriminant(value: &Kind) -> usize {
21+
match value {
22+
Kind::Anything => 0,
23+
Kind::Integer => 1,
24+
Kind::Pointer => 2,
25+
Kind::Half => 3,
26+
Kind::Float => 4,
27+
Kind::Double => 5,
28+
Kind::Unknown => 6,
29+
}
30+
}
31+
32+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
33+
pub struct TypeTree(pub Vec<Type>);
34+
35+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
36+
pub struct Type {
37+
pub offset: isize,
38+
pub size: usize,
39+
pub kind: Kind,
40+
pub child: TypeTree,
41+
}
42+
43+
impl<CTX: HashStableContext> HashStable<CTX> for Type {
44+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
45+
self.offset.hash_stable(hcx, hasher);
46+
self.size.hash_stable(hcx, hasher);
47+
self.kind.hash_stable(hcx, hasher);
48+
self.child.0.hash_stable(hcx, hasher);
49+
}
50+
}
51+
52+
impl Type {
53+
pub fn add_offset(self, add: isize) -> Self {
54+
let offset = match self.offset {
55+
-1 => add,
56+
x => add + x,
57+
};
58+
59+
Self { size: self.size, kind: self.kind, child: self.child, offset }
60+
}
61+
}
62+
63+
impl fmt::Display for Type {
64+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65+
<Self as fmt::Debug>::fmt(self, f)
66+
}
67+
}

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function
22
builtin_macros_alloc_must_statics = allocators must be statics
33
4+
builtin_macros_autodiff = autodiff must be applied to function
5+
46
builtin_macros_asm_clobber_abi = clobber_abi
57
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs
68
builtin_macros_asm_clobber_outputs = generic outputs

0 commit comments

Comments
 (0)