|
| 1 | +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb. |
| 2 | + |
| 3 | +# %% auto 0 |
| 4 | +__all__ = ['HiddenMarkovModel'] |
| 5 | + |
| 6 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 36 |
| 7 | +from fastcore.all import patch |
| 8 | + |
| 9 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 38 |
| 10 | +import jax.numpy as jnp |
| 11 | +from flax import nnx # 导入 nnx 库,里面包含了一些常用的网络层 |
| 12 | +from fastcore.all import store_attr # 导入 fastcore 基础库的 store_attr 函数,用来方便地存储类的属性,这样Python面向对象写起来不那么冗长。 请 pip install fastcore。 |
| 13 | + |
| 14 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 39 |
class HiddenMarkovModel(nnx.Module):
    """Hidden Markov Model with 3 hidden states and 2 observation categories.

    Attributes:
        ob_category (list of str, length 2): names of the observation categories.
        total_states (int): number of hidden states, fixed at 3.
        pi (array, shape (3,)): initial state probability distribution.
        A (array, shape (3, 3)): transition probabilities. Each row sums to 1;
            A[i, j] is the probability of moving from state i to state j
            (so A.T[i, j] is the probability of arriving at i from j).
        B (array, shape (3, 2)): emission probabilities. Each row sums to 1;
            B[i, k] is the probability that state i emits observation k.
    """

    def __init__(self):
        # Observation index 0 -> 'THU', index 1 -> 'PKU'.
        self.ob_category = ['THU', 'PKU']
        self.total_states = 3

        # Fixed toy parameters for the exercise, wrapped as nnx parameters.
        initial = [0.2, 0.4, 0.4]
        transition = [[0.1, 0.6, 0.3],
                      [0.3, 0.5, 0.2],
                      [0.7, 0.2, 0.1]]
        emission = [[0.5, 0.5],
                    [0.4, 0.6],
                    [0.7, 0.3]]

        self.pi = nnx.Param(jnp.array(initial))
        self.A = nnx.Param(jnp.array(transition))
        self.B = nnx.Param(jnp.array(emission))
| 42 | + |
| 43 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 45 |
@patch
def compute_likelihood_by_forward(self: HiddenMarkovModel, ob):
    """HMM Forward Algorithm.

    Args:
        ob (array, with shape (T,)): (o1, o2, ..., oT), observation indices.

    Returns:
        fwd (array, with shape (T, total_states)): fwd[t, s] is the summed
            probability over all paths that reach state s at timestep t while
            emitting the observations ob[0:t+1].
        prob: the probability of the HMM generating the whole observation
            sequence, i.e. fwd[T-1, :].sum().
    """
    T = ob.shape[0]
    fwd = jnp.zeros((T, self.total_states))

    # Begin Assignment

    # Initialization at t=0: start in each state and emit ob[0].
    # JAX arrays are immutable, so use .at[...].set instead of in-place writes.
    fwd = fwd.at[0, :].set(self.pi * self.B[:, ob[0]])

    # Induction: fwd[t, j] = B[j, ob[t]] * sum_i fwd[t-1, i] * A[i, j].
    # Vectorized over all states j at once: a single matrix-vector product per
    # timestep replaces the per-state .at[t, j].set(...) updates, each of which
    # would allocate a fresh array.
    for t in range(1, T):
        fwd = fwd.at[t, :].set(self.B[:, ob[t]] * jnp.dot(fwd[t - 1, :], self.A))

    # End Assignment

    # Total likelihood: sum over all final states.
    prob = fwd[-1, :].sum()

    return fwd, prob
| 77 | + |
| 78 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 51 |
@patch
def compute_likelihood_by_backward(self: HiddenMarkovModel, ob):
    """HMM Backward Algorithm.

    Args:
        ob (array, with shape (T,)): (o1, o2, ..., oT), observation indices.

    Returns:
        bwd (array, with shape (T, total_states)): bwd[t, s] is the summed
            probability, over all paths, of emitting the remaining
            observations ob[t+1:] given that the chain is in state s at
            timestep t.
        prob: the probability of the HMM generating the whole observation
            sequence.
    """
    T = ob.shape[0]
    bwd = jnp.zeros((T, self.total_states))

    # Begin Assignment

    # Initialization at t=T-1: no observations remain, so probability 1
    # regardless of state.
    bwd = bwd.at[T - 1, :].set(1.0)

    # Induction: bwd[t, i] = sum_j A[i, j] * B[j, ob[t+1]] * bwd[t+1, j].
    # Vectorized over all states i at once: one matrix-vector product per
    # timestep instead of a per-state .at[t, i].set(...) update.
    for t in reversed(range(T - 1)):
        bwd = bwd.at[t, :].set(jnp.dot(self.A, bwd[t + 1, :] * self.B[:, ob[t + 1]]))

    # End Assignment

    # Fold in the initial distribution and the first emission.
    prob = (bwd[0, :] * self.B[:, ob[0]] * self.pi).sum()

    return bwd, prob
| 111 | + |
| 112 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 56 |
@patch
def decode_states_by_viterbi(self: HiddenMarkovModel, ob):
    """Viterbi Decoding Algorithm.

    Args:
        ob (array, with shape (T,)): (o1, o2, ..., oT), observation indices.

    Variables:
        delta (array, with shape (T, total_states)): delta[t, s] is the
            probability of the single most likely path that reaches state s
            at timestep t while emitting ob[0:t+1].
        phi (array, with shape (T, total_states)): phi[t, s] is the state at
            timestep t-1 on that most likely path into (t, s), used for
            backtracking.

    Returns:
        best_prob: the probability of the best state sequence.
        best_path: the best state sequence (array of state indices, shape (T,)).
    """
    T = ob.shape[0]
    delta = jnp.zeros((T, self.total_states))
    phi = jnp.zeros((T, self.total_states), jnp.int32)
    best_prob, best_path = 0.0, jnp.zeros(T, dtype=jnp.int32)

    # Begin Assignment

    # Initialization at t=0: start in each state and emit ob[0].
    delta = delta.at[0, :].set(self.pi * self.B[:, ob[0]])

    # Dynamic-programming recursion, vectorized over all (i, j) state pairs:
    # scores[i, j] = delta[t-1, i] * A[i, j] * B[j, ob[t]].
    # max/argmax over axis 0 give delta[t, :] and phi[t, :] in one shot,
    # replacing the Python-level max() over (prob, index) tuples of traced
    # scalars, which allocated per-element and would not jit.
    # NOTE: on exact probability ties argmax keeps the smallest predecessor
    # index (the usual convention); the tuple-max picked the largest.
    for t in range(1, T):
        scores = delta[t - 1, :][:, None] * self.A * self.B[:, ob[t]][None, :]
        delta = delta.at[t, :].set(scores.max(axis=0))
        phi = phi.at[t, :].set(scores.argmax(axis=0))

    # End Assignment

    # Backtracking: pick the best final state, then follow phi backwards.
    best_path = best_path.at[T - 1].set(delta[T - 1, :].argmax(0))
    best_prob = delta[T - 1, best_path[T - 1]]
    for t in reversed(range(T - 1)):
        best_path = best_path.at[t].set(phi[t + 1, best_path[t + 1]])

    return best_prob, best_path
| 158 | + |
| 159 | +# %% ../../notebooks/coding_projects/digital_processing_of_speech_signals/P2_HMM/00hidden_markov_model.ipynb 60 |
if __name__ == "__main__":
    # Smoke test: run all three algorithms on a short observation sequence.
    hmm = HiddenMarkovModel()
    obs = jnp.array([0, 1, 0, 1, 1])  # [THU, PKU, THU, PKU, PKU]

    forward_probs, likelihood = hmm.compute_likelihood_by_forward(obs)
    print(likelihood)
    print(forward_probs)

    backward_probs, likelihood = hmm.compute_likelihood_by_backward(obs)
    print(likelihood)
    print(backward_probs)

    viterbi_prob, viterbi_path = hmm.decode_states_by_viterbi(obs)
    print(viterbi_prob)
    print(viterbi_path)
0 commit comments