Skip to content

Commit a484559

Browse files
authored
add hsmm helper (#112)
1 parent e51fecc commit a484559

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

tests/extensions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,56 @@ def enumerate(semiring, edge):
165165
ls = [s for (_, s) in chains[N]]
166166
return semiring.unconvert(semiring.sum(torch.stack(ls, dim=1), dim=1)), ls
167167

168+
@staticmethod
169+
def enumerate_hsmm(semiring, init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z):
170+
ssize = semiring.size()
171+
batch, N, K, C = emission_n_l_z.shape
172+
173+
if init_z_1.dim() == 1:
174+
init_z_1 = init_z_1.unsqueeze(0).expand(batch, C) # batch, C
175+
transition_z_to_z = transition_z_to_z.unsqueeze(0).expand(batch, C, C)
176+
transition_z_to_l = transition_z_to_l.unsqueeze(0).expand(batch, C, K)
177+
178+
init_z_1 = semiring.convert(init_z_1) # ssize, batch, C
179+
transition_z_to_z = semiring.convert(transition_z_to_z) # ssize, batch, C, C
180+
transition_z_to_l = semiring.convert(transition_z_to_l) # ssize, batch, C, K
181+
emission_n_l_z = semiring.convert(emission_n_l_z) # ssize, batch, N, K, C
182+
183+
def score_chain(chain):
184+
score = semiring.fill(torch.zeros(ssize, batch), torch.tensor(True), semiring.one)
185+
state_0, _ = chain[0]
186+
# P(z_{-1})
187+
score = semiring.mul(score, init_z_1[:, :, state_0])
188+
prev_state = state_0
189+
n = 0
190+
for t in range(len(chain) - 1):
191+
state, k = chain[t + 1]
192+
# P(z_t | z_{t-1})
193+
score = semiring.mul(score, transition_z_to_z[:, :, prev_state, state])
194+
# P(l_t | z_t)
195+
score = semiring.mul(score, transition_z_to_l[:, :, state, k])
196+
# P(x_{n:n+l_t} | z_t, l_t)
197+
score = semiring.mul(score, emission_n_l_z[:, :, n, k, state])
198+
prev_state = state
199+
n += k
200+
return score
201+
202+
chains = {}
203+
chains[0] = [
204+
[(c, 0)] for c in range(C)
205+
]
206+
207+
for n in range(1, N + 1):
208+
chains[n] = []
209+
for k in range(1, K):
210+
if n - k not in chains:
211+
continue
212+
for chain in chains[n - k]:
213+
for c in range(C):
214+
chains[n].append(chain + [(c, k)])
215+
ls = [score_chain(chain) for chain in chains[N]]
216+
return semiring.unconvert(semiring.sum(torch.stack(ls, dim=1), dim=1)), ls
217+
168218

169219
### Tests
170220

tests/test_algorithms.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,3 +499,26 @@ def ignore_alignment(data):
499499
# assert torch.isclose(count, alpha).all()
500500
struct = model(semiring, max_gap=1)
501501
alpha = struct.sum(vals)
502+
503+
504+
@pytest.mark.parametrize("model_test", ["SemiMarkov"])
@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
def test_hsmm(model_test, semiring):
    """Check that three ways of scoring an HSMM agree on the partition."""
    C, K, batch, N = 5, 3, 2, 5
    init_z_1 = torch.rand(batch, C)
    transition_z_to_z = torch.rand(C, C)
    transition_z_to_l = torch.rand(C, K)
    emission_n_l_z = torch.rand(batch, N, K, C)

    enum = algorithms[model_test][1]
    # first way: brute-force directly on the HSMM parameters
    partition_direct = enum.enumerate_hsmm(
        semiring, init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z
    )[0]
    # second way: brute-force on edge scores assembled from the same parameters
    edge = SemiMarkov.hsmm(init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z)
    partition_enum = enum.enumerate(semiring, edge)[0]
    # third way: dynamic program on the same edge scores
    partition_dp = algorithms[model_test][0](semiring).logpartition(edge)[0]

    assert torch.isclose(partition_direct, partition_enum).all()
    assert torch.isclose(partition_enum, partition_dp).all()

torch_struct/semimarkov.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,44 @@ def from_parts(edge):
173173
labels[on[i][0], on[i][1] + on[i][2]] = on[i][3]
174174
# print(edge.nonzero(), labels)
175175
return labels, (C, K)
176+
177+
# Adapters
178+
@staticmethod
179+
def hsmm(init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z):
180+
"""
181+
Convert HSMM log-probs to edge scores.
182+
183+
Parameters:
184+
init_z_1: C or b x C (init_z[i] = log P(z_{-1}=i), note that z_{-1} is an
185+
auxiliary state whose purpose is to induce a distribution over z_0.)
186+
transition_z_to_z: C X C (transition_z_to_z[i][j] = log P(z_{n+1}=j | z_n=i),
187+
note that the order of z_{n+1} and z_n is different
188+
from `edges`.)
189+
transition_z_to_l: C X K (transition_z_to_l[i][j] = P(l_n=j | z_n=i))
190+
emission_n_l_z: b x N x K x C
191+
192+
Returns:
193+
edges: b x (N-1) x K x C x C, where edges[b, n, k, c2, c1]
194+
= log P(z_n=c2 | z_{n-1}=c1) + log P(l_n=k | z_n=c2)
195+
+ log P(x_{n:n+l_n} | z_n=c2, l_n=k), if n>0
196+
= log P(z_n=c2 | z_{n-1}=c1) + log P(l_n=k | z_n=c2)
197+
+ log P(x_{n:n+l_n} | z_n=c2, l_n=k) + log P(z_{-1}), if n=0
198+
"""
199+
batch, N, K, C = emission_n_l_z.shape
200+
edges = torch.zeros(batch, N, K, C, C).type_as(emission_n_l_z)
201+
202+
# initial state: log P(z_{-1})
203+
if init_z_1.dim() == 1:
204+
init_z_1 = init_z_1.unsqueeze(0).expand(batch, -1)
205+
edges[:, 0, :, :, :] += init_z_1.view(batch, 1, 1, C)
206+
207+
# transitions: log P(z_n | z_{n-1})
208+
edges += transition_z_to_z.transpose(-1, -2).view(1, 1, 1, C, C)
209+
210+
# l given z: log P(l_n | z_n)
211+
edges += transition_z_to_l.transpose(-1, -2).view(1, 1, K, C, 1)
212+
213+
# emissions: log P(x_{n:n+l_n} | z_n, l_n)
214+
edges += emission_n_l_z.view(batch, N, K, C, 1)
215+
216+
return edges

0 commit comments

Comments
 (0)