|
| 1 | +//! Implementation of a Gaussian Naive Bayes Classifier |
| 2 | +
|
1 | 3 | use std::cmp::Ordering; |
2 | 4 | use std::collections::HashMap; |
3 | 5 | use std::f64; |
4 | 6 | use std::fmt; |
5 | 7 | use std::hash::Hash; |
6 | 8 | use std::iter::FromIterator; |
7 | 9 |
|
| 10 | +/// A Gaussian Naive Bayes Classifier |
| 11 | +/// |
| 12 | +/// The classifier is trained by consuming an iterator over the training data: |
| 13 | +/// ```ignore |
| 14 | +/// let nbc: NaiveBayesClassifier<_> = data |
| 15 | +/// .iter() |
| 16 | +/// .collect(); |
| 17 | +/// ``` |
8 | 18 | #[derive(Debug)] |
9 | 19 | pub struct NaiveBayesClassifier<C> |
10 | 20 | where C: Eq + Hash |
11 | 21 | { |
12 | 22 | class_distributions: HashMap<C, FeatureDistribution>, |
13 | 23 | } |
14 | 24 |
|
| 25 | +/// Per-class collection of normal distributions, one for each feature column |
15 | 26 | #[derive(Debug, Clone)] |
16 | 27 | struct FeatureDistribution { |
17 | | - distributions: Vec<UniformNormalDistribution> |
| 28 | + distributions: Vec<NormalDistribution> |
18 | 29 | } |
19 | 30 |
|
| 31 | +/// Univariate Normal Distribution |
20 | 32 | #[derive(Copy, Clone)] |
21 | | -struct UniformNormalDistribution { |
| 33 | +struct NormalDistribution { |
22 | 34 | sum: f64, |
23 | 35 | sqsum: f64, |
24 | 36 | n: usize |
|
40 | 52 |
|
41 | 53 | for (i, &xi) in x.into_iter().enumerate() { |
42 | 54 | if i >= distributions.len() { |
43 | | - distributions.resize(1 + i, UniformNormalDistribution::new()); |
| 55 | + distributions.resize(1 + i, NormalDistribution::new()); |
44 | 56 | } |
45 | 57 |
|
46 | 58 | distributions[i].update(xi); |
|
56 | 68 | impl<C> NaiveBayesClassifier<C> |
57 | 69 | where C: Eq + Hash + Copy, |
58 | 70 | { |
| 71 | + /// Predict the target class for a single feature vector. |
59 | 72 | pub fn predict(&self, x: &[f64]) -> C { |
60 | 73 | self.class_distributions |
61 | 74 | .iter() |
@@ -88,9 +101,9 @@ impl FeatureDistribution { |
88 | 101 | } |
89 | 102 | } |
90 | 103 |
|
91 | | -impl UniformNormalDistribution { |
| 104 | +impl NormalDistribution { |
92 | 105 | fn new() -> Self { |
93 | | - UniformNormalDistribution { |
| 106 | + NormalDistribution { |
94 | 107 | sum: 0.0, |
95 | 108 | sqsum: 0.0, |
96 | 109 | n: 0 |
@@ -120,7 +133,7 @@ impl UniformNormalDistribution { |
120 | 133 | } |
121 | 134 | } |
122 | 135 |
|
123 | | -impl fmt::Debug for UniformNormalDistribution { |
| 136 | +impl fmt::Debug for NormalDistribution { |
124 | 137 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { |
125 | 138 | write!(f, "N{{{}; {}}}", self.mean(), self.variance()) |
126 | 139 | } |
|
0 commit comments