|
| 1 | +use super::typetree::TypeTree; |
| 2 | +use std::str::FromStr; |
| 3 | +use rustc_data_structures::stable_hasher::{HashStable, StableHasher};//, StableOrd}; |
| 4 | +use crate::HashStableContext; |
| 5 | + |
| 6 | +#[allow(dead_code)] |
| 7 | +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] |
| 8 | +pub enum DiffMode { |
| 9 | + Inactive, |
| 10 | + Source, |
| 11 | + Forward, |
| 12 | + Reverse, |
| 13 | +} |
| 14 | + |
| 15 | +#[allow(dead_code)] |
| 16 | +#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug)] |
| 17 | +pub enum DiffActivity { |
| 18 | + None, |
| 19 | + Active, |
| 20 | + Const, |
| 21 | + Duplicated, |
| 22 | + DuplicatedNoNeed, |
| 23 | +} |
| 24 | +fn clause_diffactivity_discriminant(value: &DiffActivity) -> usize { |
| 25 | + match value { |
| 26 | + DiffActivity::None => 0, |
| 27 | + DiffActivity::Active => 1, |
| 28 | + DiffActivity::Const => 2, |
| 29 | + DiffActivity::Duplicated => 3, |
| 30 | + DiffActivity::DuplicatedNoNeed => 4, |
| 31 | + } |
| 32 | +} |
| 33 | +fn clause_diffmode_discriminant(value: &DiffMode) -> usize { |
| 34 | + match value { |
| 35 | + DiffMode::Inactive => 0, |
| 36 | + DiffMode::Source => 1, |
| 37 | + DiffMode::Forward => 2, |
| 38 | + DiffMode::Reverse => 3, |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | + |
| 43 | +impl<CTX: HashStableContext> HashStable<CTX> for DiffMode { |
| 44 | + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { |
| 45 | + clause_diffmode_discriminant(self).hash_stable(hcx, hasher); |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +impl<CTX: HashStableContext> HashStable<CTX> for DiffActivity { |
| 50 | + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { |
| 51 | + clause_diffactivity_discriminant(self).hash_stable(hcx, hasher); |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | + |
| 56 | +impl FromStr for DiffActivity { |
| 57 | + type Err = (); |
| 58 | + |
| 59 | + fn from_str(s: &str) -> Result<DiffActivity, ()> { |
| 60 | + match s { |
| 61 | + "None" => Ok(DiffActivity::None), |
| 62 | + "Active" => Ok(DiffActivity::Active), |
| 63 | + "Const" => Ok(DiffActivity::Const), |
| 64 | + "Duplicated" => Ok(DiffActivity::Duplicated), |
| 65 | + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), |
| 66 | + _ => Err(()), |
| 67 | + } |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +#[allow(dead_code)] |
| 72 | +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] |
| 73 | +pub struct AutoDiffAttrs { |
| 74 | + pub mode: DiffMode, |
| 75 | + pub ret_activity: DiffActivity, |
| 76 | + pub input_activity: Vec<DiffActivity>, |
| 77 | +} |
| 78 | + |
| 79 | +impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffAttrs { |
| 80 | + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { |
| 81 | + self.mode.hash_stable(hcx, hasher); |
| 82 | + self.ret_activity.hash_stable(hcx, hasher); |
| 83 | + self.input_activity.hash_stable(hcx, hasher); |
| 84 | + } |
| 85 | +} |
| 86 | + |
| 87 | +impl AutoDiffAttrs { |
| 88 | + pub fn inactive() -> Self { |
| 89 | + AutoDiffAttrs { |
| 90 | + mode: DiffMode::Inactive, |
| 91 | + ret_activity: DiffActivity::None, |
| 92 | + input_activity: Vec::new(), |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + pub fn is_active(&self) -> bool { |
| 97 | + match self.mode { |
| 98 | + DiffMode::Inactive => false, |
| 99 | + _ => true, |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + pub fn is_source(&self) -> bool { |
| 104 | + match self.mode { |
| 105 | + DiffMode::Source => true, |
| 106 | + _ => false, |
| 107 | + } |
| 108 | + } |
| 109 | + pub fn apply_autodiff(&self) -> bool { |
| 110 | + match self.mode { |
| 111 | + DiffMode::Inactive => false, |
| 112 | + DiffMode::Source => false, |
| 113 | + _ => true, |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + pub fn into_item( |
| 118 | + self, |
| 119 | + source: String, |
| 120 | + target: String, |
| 121 | + inputs: Vec<TypeTree>, |
| 122 | + output: TypeTree, |
| 123 | + ) -> AutoDiffItem { |
| 124 | + AutoDiffItem { source, target, inputs, output, attrs: self } |
| 125 | + } |
| 126 | +} |
| 127 | + |
| 128 | +#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug)] |
| 129 | +pub struct AutoDiffItem { |
| 130 | + pub source: String, |
| 131 | + pub target: String, |
| 132 | + pub attrs: AutoDiffAttrs, |
| 133 | + pub inputs: Vec<TypeTree>, |
| 134 | + pub output: TypeTree, |
| 135 | +} |
| 136 | + |
| 137 | +impl<CTX: HashStableContext> HashStable<CTX> for AutoDiffItem { |
| 138 | + fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) { |
| 139 | + self.source.hash_stable(hcx, hasher); |
| 140 | + self.target.hash_stable(hcx, hasher); |
| 141 | + self.attrs.hash_stable(hcx, hasher); |
| 142 | + for tt in &self.inputs { |
| 143 | + tt.0.hash_stable(hcx, hasher); |
| 144 | + } |
| 145 | + //self.inputs.hash_stable(hcx, hasher); |
| 146 | + self.output.0.hash_stable(hcx, hasher); |
| 147 | + } |
| 148 | +} |
| 149 | + |
0 commit comments