Commit 90a3b69

fix semimarkov batching and add tests (#114)
* fix semimarkov batching and add tests
* .
* batched implementation
* add back dp_standard
1 parent 5328ec5 commit 90a3b69

2 files changed: +64, -35 lines


tests/test_algorithms.py

Lines changed: 26 additions & 0 deletions
@@ -519,6 +519,32 @@ def test_hsmm(model_test, semiring):
     partition2 = algorithms[model_test][1].enumerate(semiring, edge)[0]
     # third way: dp using edge scores computed from init/transitions/emission
     partition3 = algorithms[model_test][0](semiring).logpartition(edge)[0]
+    # fourth way: dp_standard using edge scores computed from init/transitions/emission
+    partition4 = algorithms[model_test][0](semiring)._dp_standard(edge)[0]
 
     assert torch.isclose(partition1, partition2).all()
     assert torch.isclose(partition2, partition3).all()
+    assert torch.isclose(partition3, partition4).all()
+
+
+@given(data())
+@pytest.mark.parametrize("model_test", ["SemiMarkov"])
+@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
+def test_batching_lengths(model_test, semiring, data):
+    "Test batching"
+    gen = Gen(model_test, data, LogSemiring)
+    model, vals, N, batch = gen.model, gen.vals, gen.N, gen.batch
+    lengths = torch.tensor(
+        [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
+    )
+    # first way: batched implementation
+    partition = model(semiring).logpartition(vals, lengths=lengths)[0][0]
+    # second way: unbatched implementation
+    for b in range(batch):
+        vals_b = vals[b:(b + 1), :(lengths[b] - 1)]
+        lengths_b = lengths[b:(b + 1)]
+        partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0]
+        assert torch.isclose(partition[b], partition_b).all()
+    # test _dp_standard
+    partition_dp_standard = model(semiring)._dp_standard(vals, lengths=lengths)[0][0]
+    assert torch.isclose(partition, partition_dp_standard).all()
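Taken outside the test harness, the new check boils down to comparing one batched logpartition call against per-sequence calls. A minimal standalone sketch of that comparison (not part of the commit; the import paths, the (batch, N - 1, K, C, C) potential layout, and the concrete sizes are assumptions here):

import torch
from torch_struct import LogSemiring
from torch_struct.semimarkov import SemiMarkov

batch, N, K, C = 2, 6, 3, 4
vals = torch.rand(batch, N - 1, K, C, C)   # random log-potentials (assumed layout)
lengths = torch.tensor([4, N])             # per-sequence lengths; the test keeps one at N

# batched partition: one value per batch element
batched = SemiMarkov(LogSemiring).logpartition(vals, lengths=lengths)[0][0]

# unbatched: score each sequence on its own and compare
for b in range(batch):
    single = SemiMarkov(LogSemiring).logpartition(
        vals[b:(b + 1), :(lengths[b] - 1)], lengths=lengths[b:(b + 1)]
    )[0][0]
    assert torch.isclose(batched[b], single).all()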

torch_struct/semimarkov.py

Lines changed: 38 additions & 35 deletions
@@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         )
 
         # Init.
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
         mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
         init = semiring.fill(init, mask, semiring.one)
 
@@ -61,10 +61,13 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         c[:, :, : K - 1, 0] = semiring.sum(
             torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
         )
-        end = torch.min(lengths) - 1
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
+        mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
+        mask_length = mask_length.to(log_potentials.device)
         for k in range(1, K - 1):
-            mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+            mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
+            mask_length_k = semiring.convert(mask_length_k)
+            mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True)
         init = semiring.fill(init, mask, semiring.one)
 
         K_1 = K - 1
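The substantive fix in this hunk: the old code truncated the init mask at torch.min(lengths) - 1, so only the shortest sequence in the batch got the right boundary, while the new code builds a per-sequence cutoff by comparing a position index against each length. A toy illustration of that broadcast comparison, with made-up sizes and the semiring/diagonal bookkeeping stripped out (illustrative only, not library code):

import torch

batch, bin_N, C = 3, 8, 4
lengths = torch.tensor([5, 8, 6])
k = 2  # one of the span widths iterated over above

# positions 0..bin_N-1, broadcast against each sequence's own cutoff
positions = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
mask_length_k = positions < (lengths - 1 - (k - 1)).view(batch, 1, 1)

# each row switches off at its own length rather than at min(lengths)
print(mask_length_k[:, :, 0])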
@@ -83,37 +86,37 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
         return v, [log_potentials]
 
-    # def _dp_standard(self, edge, lengths=None, force_grad=False):
-    #     semiring = self.semiring
-    #     ssize = semiring.size()
-    #     edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
-    #     edge.requires_grad_(True)
-
-    #     # Init
-    #     # All paths starting at N of len K
-    #     alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
-
-    #     # All paths finishing at N with label C
-    #     beta = self._make_chart(N, (batch, C), edge, force_grad)
-    #     semiring.one_(beta[0].data)
-
-    #     # Main.
-    #     for n in range(1, N):
-    #         alpha[:, :, n - 1] = semiring.dot(
-    #             beta[n - 1].view(ssize, batch, 1, 1, C),
-    #             edge[:, :, n - 1].view(ssize, batch, K, C, C),
-    #         )
-
-    #         t = max(n - K, -1)
-    #         f1 = torch.arange(n - 1, t, -1)
-    #         f2 = torch.arange(1, len(f1) + 1)
-    #         beta[n][:] = semiring.sum(
-    #             torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
-    #         )
-    #     v = semiring.sum(
-    #         torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
-    #     )
-    #     return v, [edge], beta
+    def _dp_standard(self, edge, lengths=None, force_grad=False):
+        semiring = self.semiring
+        ssize = semiring.size()
+        edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
+        edge.requires_grad_(True)
+
+        # Init
+        # All paths starting at N of len K
+        alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]
+
+        # All paths finishing at N with label C
+        beta = self._make_chart(N, (batch, C), edge, force_grad)
+        beta[0] = semiring.fill(beta[0], torch.tensor(True).to(edge.device), semiring.one)
+
+        # Main.
+        for n in range(1, N):
+            alpha[:, :, n - 1] = semiring.dot(
+                beta[n - 1].view(ssize, batch, 1, 1, C),
+                edge[:, :, n - 1].view(ssize, batch, K, C, C),
+            )
+
+            t = max(n - K, -1)
+            f1 = torch.arange(n - 1, t, -1)
+            f2 = torch.arange(1, len(f1) + 1)
+            beta[n][:] = semiring.sum(
+                torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
+            )
+        v = semiring.sum(
+            torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
+        )
+        return v, [edge], beta
 
     @staticmethod
     def to_parts(sequence, extra, lengths=None):
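For readers reconstructing what the restored _dp_standard computes: roughly, beta[n][c] accumulates all ways of reaching position n with label c, and alpha folds in one more segment of each width k, as read from the recursion above. A single-sequence, plain log-space sketch of the same recursion (illustrative only; no semiring machinery, and the index conventions follow the code rather than any documented API):

import torch

def semimarkov_logpartition_reference(edge, length):
    # edge: (N - 1, K, C, C) log-potentials for a single sequence;
    # edge[n, k, c, c_prev] is the score the recursion above consumes at
    # step n for span width k, moving to label c from c_prev.
    N = edge.shape[0] + 1
    K, C = edge.shape[1], edge.shape[2]
    beta = [None] * N
    beta[0] = torch.zeros(C)                        # log(1): every start label allowed
    alpha = torch.full((N, K, C), float("-inf"))
    for n in range(1, N):
        # alpha[n-1, k, c] = logsumexp_{c'} beta[n-1, c'] + edge[n-1, k, c, c']
        alpha[n - 1] = torch.logsumexp(beta[n - 1].view(1, 1, C) + edge[n - 1], dim=-1)
        # beta[n, c]: sum over spans of width j ending here, mirroring the f1/f2 pairing
        spans = [alpha[n - j, j] for j in range(1, min(n, K - 1) + 1)]
        beta[n] = torch.logsumexp(torch.stack(spans, dim=-1), dim=-1)
    return torch.logsumexp(beta[length - 1], dim=-1)

# e.g. semimarkov_logpartition_reference(torch.rand(5, 3, 4, 4), length=6)

This sequential form is the easy-to-audit baseline; the point of adding it back is that the tests above can assert agreement between the batched logpartition and _dp_standard.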
