Skip to content

Commit 175d236

Browse files
committed
cleanup, first helper
1 parent 12292ca commit 175d236

File tree

6 files changed

+250
-0
lines changed

6 files changed

+250
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
use super::typetree::TypeTree;
2+
use std::str::FromStr;
3+
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd};
4+
use crate::HashStableContext;
5+
6+
#[allow(dead_code)]
7+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
8+
pub enum DiffMode {
9+
Inactive,
10+
Source,
11+
Forward,
12+
Reverse,
13+
}
14+
15+
#[allow(dead_code)]
16+
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)]
17+
pub enum DiffActivity {
18+
None,
19+
Active,
20+
Const,
21+
Duplicated,
22+
DuplicatedNoNeed,
23+
}
24+
fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize {
25+
match value {
26+
DiffActivity::None => 0,
27+
DiffActivity::Active => 1,
28+
DiffActivity::Const => 2,
29+
DiffActivity::Duplicated => 3,
30+
DiffActivity::DuplicatedNoNeed => 4,
31+
}
32+
}
33+
fn clause_diffmode_discriminant(value: &DiffMode) -> usize {
34+
match value {
35+
DiffMode::Inactive => 0,
36+
DiffMode::Source => 1,
37+
DiffMode::Forward => 2,
38+
DiffMode::Reverse => 3,
39+
}
40+
}
41+
42+
43+
impl<CTX: HashStableContext> HashStable<CTX> for DiffMode {
44+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
45+
clause_diffmode_discriminant(self).hash_stable(hcx, hasher);
46+
}
47+
}
48+
49+
impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity {
50+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
51+
clause_diffactivity_discriminant(self).hash_stable(hcx, hasher);
52+
}
53+
}
54+
55+
56+
impl FromStr for DiffActivity {
57+
type Err = ();
58+
59+
fn from_str(s: &str) -> Result<DiffActivity, ()> {
60+
match s {
61+
"None" => Ok(DiffActivity::None),
62+
"Active" => Ok(DiffActivity::Active),
63+
"Const" => Ok(DiffActivity::Const),
64+
"Duplicated" => Ok(DiffActivity::Duplicated),
65+
"DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed),
66+
_ => Err(()),
67+
}
68+
}
69+
}
70+
71+
#[allow(dead_code)]
72+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)]
73+
pub struct AutoDiffAttrs {
74+
pub mode: DiffMode,
75+
pub ret_activity: DiffActivity,
76+
pub input_activity: Vec<DiffActivity>,
77+
}
78+
79+
impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs {
80+
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
81+
self.mode.hash_stable(hcx, hasher);
82+
self.ret_activity.hash_stable(hcx, hasher);
83+
self.input_activity.hash_stable(hcx, hasher);
84+
}
85+
}
86+
87+
impl AutoDiffAttrs {
88+
pub fn inactive() -> Self {
89+
AutoDiffAttrs {
90+
mode: DiffMode::Inactive,
91+
ret_activity: DiffActivity::None,
92+
input_activity: Vec::new(),
93+
}
94+
}
95+
96+
pub fn is_active(&self) -> bool {
97+
match self.mode {
98+
DiffMode::Inactive => false,
99+
_ => {
100+
dbg!(&self);
101+
true
102+
},
103+
}
104+
}
105+
106+
pub fn is_source(&self) -> bool {
107+
dbg!(&self);
108+
match self.mode {
109+
DiffMode::Source => true,
110+
_ => false,
111+
}
112+
}
113+
pub fn apply_autodiff(&self) -> bool {
114+
match self.mode {
115+
DiffMode::Inactive => false,
116+
DiffMode::Source => false,
117+
_ => {
118+
dbg!(&self);
119+
true
120+
},
121+
}
122+
}
123+
124+
pub fn into_item(
125+
self,
126+
source: String,
127+
target: String,
128+
inputs: Vec<TypeTree>,
129+
output: TypeTree,
130+
) -> AutoDiffItem {
131+
dbg!(&self);
132+
AutoDiffItem { source, target, inputs, output, attrs: self }
133+
}
134+
}
135+
136+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
137+
pub struct AutoDiffItem {
138+
pub source: String,
139+
pub target: String,
140+
pub attrs: AutoDiffAttrs,
141+
pub inputs: Vec<TypeTree>,
142+
pub output: TypeTree,
143+
}
144+
145+
//impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem {
146+
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
147+
// self.source.hash_stable(hcx, hasher);
148+
// self.target.hash_stable(hcx, hasher);
149+
// self.attrs.hash_stable(hcx, hasher);
150+
// for tt in &self.inputs {
151+
// tt.0.hash_stable(hcx, hasher);
152+
// }
153+
// //self.inputs.hash_stable(hcx, hasher);
154+
// self.output.0.hash_stable(hcx, hasher);
155+
// }
156+
//}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use std::fmt;
2+
//use rustc_data_structures::stable_hasher::{HashStable};//, StableHasher};
3+
//use crate::HashStableContext;
4+
5+
6+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
7+
pub enum Kind {
8+
Anything,
9+
Integer,
10+
Pointer,
11+
Half,
12+
Float,
13+
Double,
14+
Unknown,
15+
}
16+
//impl<CTX: HashStableContext> HashStable<CTX> for Kind {
17+
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
18+
// clause_kind_discriminant(self).hash_stable(hcx, hasher);
19+
// }
20+
//}
21+
//fn clause_kind_discriminant(value: &Kind) -> usize {
22+
// match value {
23+
// Kind::Anything => 0,
24+
// Kind::Integer => 1,
25+
// Kind::Pointer => 2,
26+
// Kind::Half => 3,
27+
// Kind::Float => 4,
28+
// Kind::Double => 5,
29+
// Kind::Unknown => 6,
30+
// }
31+
//}
32+
33+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
34+
pub struct TypeTree(pub Vec<Type>);
35+
36+
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
37+
pub struct Type {
38+
pub offset: isize,
39+
pub size: usize,
40+
pub kind: Kind,
41+
pub child: TypeTree,
42+
}
43+
44+
//impl<CTX: HashStableContext> HashStable<CTX> for Type {
45+
// fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
46+
// self.offset.hash_stable(hcx, hasher);
47+
// self.size.hash_stable(hcx, hasher);
48+
// self.kind.hash_stable(hcx, hasher);
49+
// self.child.0.hash_stable(hcx, hasher);
50+
// }
51+
//}
52+
53+
impl Type {
54+
pub fn add_offset(self, add: isize) -> Self {
55+
let offset = match self.offset {
56+
-1 => add,
57+
x => add + x,
58+
};
59+
60+
Self { size: self.size, kind: self.kind, child: self.child, offset }
61+
}
62+
}
63+
64+
impl fmt::Display for Type {
65+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66+
<Self as fmt::Debug>::fmt(self, f)
67+
}
68+
}

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use crate::errors;
44
//use crate::util::check_builtin_macro_attribute;
5+
//use crate::util::check_autodiff;
56

67
use rustc_ast::ptr::P;
78
use rustc_ast::{self as ast, FnHeader, FnSig, Generics, StmtKind};

compiler/rustc_passes/messages.ftl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ passes_abi_ne =
1313
passes_abi_of =
1414
fn_abi_of({$fn_name}) = {$fn_abi}
1515
16+
passes_autodiff_attr =
17+
`#[autodiff]` should be applied to a function
18+
.label = not a function
19+
1620
passes_allow_incoherent_impl =
1721
`rustc_allow_incoherent_impl` attribute should be applied to impl items.
1822
.label = the only currently supported targets are inherent methods

compiler/rustc_passes/src/check_attr.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ impl CheckAttrVisitor<'_> {
232232
self.check_generic_attr(hir_id, attr, target, Target::Fn);
233233
self.check_proc_macro(hir_id, target, ProcMacroKind::Derive)
234234
}
235+
sym::autodiff => self.check_autodiff(hir_id, attr, span, target),
235236
_ => {}
236237
}
237238

@@ -2382,6 +2383,18 @@ impl CheckAttrVisitor<'_> {
23822383
self.abort.set(true);
23832384
}
23842385
}
2386+
2387+
/// Checks if `#[autodiff]` is applied to an item other than a function item.
2388+
fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, span: Span, target: Target) {
2389+
dbg!("check_autodiff");
2390+
match target {
2391+
Target::Fn => {}
2392+
_ => {
2393+
self.tcx.sess.emit_err(errors::AutoDiffAttr { attr_span: span });
2394+
self.abort.set(true);
2395+
}
2396+
}
2397+
}
23852398
}
23862399

23872400
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {

compiler/rustc_passes/src/errors.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ pub struct IncorrectDoNotRecommendLocation {
2424
pub span: Span,
2525
}
2626

27+
#[derive(Diagnostic)]
28+
#[diag(passes_autodiff_attr)]
29+
pub struct AutoDiffAttr {
30+
#[primary_span]
31+
#[label]
32+
pub attr_span: Span,
33+
}
34+
2735
#[derive(LintDiagnostic)]
2836
#[diag(passes_outer_crate_level_attr)]
2937
pub struct OuterCrateLevelAttr;

0 commit comments

Comments
 (0)