30 changes: 30 additions & 0 deletions tests/test_alignment_crf.py
@@ -0,0 +1,30 @@
import torch
import torch_struct
import warnings


def test_alignment_crf_shapes():
    batch, N, M = 2, 4, 5
    log_potentials = torch.rand(batch, N, M, 3)

    if torch.cuda.is_available():
        log_potentials = log_potentials.cuda()
    else:
        warnings.warn('Could not move log potentials to CUDA device. '
                      'Will not test marginals.')

    dist = torch_struct.AlignmentCRF(log_potentials)
    assert (batch, N, M, 3) == dist.argmax.shape
    if torch.cuda.is_available():
        assert (batch, N, M, 3) == dist.marginals.shape
    assert (batch,) == dist.partition.shape

    # Fails due to AttributeError: 'BandedMatrix' object has no attribute
    # 'unsqueeze'
    # assert (batch,) == dist.entropy.shape
    # assert (9, batch, N, M, 3) == dist.sample([9]).shape

    # Fails due to: RuntimeError: Expected condition, x and y to be on
    # the same device, but condition is on cpu and x and y are on
    # cuda:0 and cuda:0 respectively
    # assert (8, batch,) == dist.topk(8).shape
9 changes: 5 additions & 4 deletions torch_struct/alignment.py
@@ -1,11 +1,14 @@
 import torch
 from .helpers import _Struct
 import math
+import warnings

 try:
     import genbmm

 except ImportError:
-    pass
+    warnings.warn('Could not import genbmm. '
+                  'However, genbmm is only used for CUDA operations.')

 from .semirings import LogSemiring
+from .semirings.fast_semirings import broadcast
@@ -97,9 +100,7 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
             # Create finalizing paths.
             point = (l + M) // 2

-            charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
-                charta[1][:, b, point:, 1, ind, :, :, Mid]
-            )
+            charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0)
Collaborator commented:

Unfortunately this is not going to work.

We need to call

init = torch.zeros(charta[1].shape).bool()
init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
charta[1] = semiring.fill(charta[1], init, semiring.one)

Collaborator commented:

(this should fix your other issues too)
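
For context, here is the suggested pattern in isolation: a minimal sketch assuming the post-#105 semiring.fill / semiring.one API referenced above. The tensor shape and mask indices are illustrative stand-ins for charta[1] and [:, b, point:, 1, ind, :, :, Mid]:

import torch
from torch_struct.semirings import LogSemiring

charta1 = torch.rand(2, 4, 5)             # illustrative stand-in for charta[1]
init = torch.zeros(charta1.shape).bool()  # boolean mask of entries to initialize
init[:, 2:, 1].fill_(True)                # illustrative stand-in for the real indices
charta1 = LogSemiring.fill(charta1, init, LogSemiring.one)

The point of the pattern is that the tensor is rebuilt through semiring.fill with a boolean mask rather than mutated in place, so each semiring supplies its own notion of "one" (0.0 in log space).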

@JohnReid (Contributor Author) commented on Oct 13, 2021:

Great, thanks for that. I have to admit I had just copied the code from the one_() method before it was removed in #105, and assumed it was the correct code.

@JohnReid (Contributor Author) commented on Oct 13, 2021:

I'm still facing a few issues though. I fixed some of them in the commits below, but others remain. The main sticking point seems to be that BandedMatrix objects are not correctly dispatched to multiply rather than matmul() in semirings.py; the matmul implementation only works for standard tensors. This affects dist.entropy, dist.sample(), and dist.topk(), but not partition, argmax, or marginals.

I tried to fix this rather naively by overloading the classmethod matmul in some of the semirings, but this broke the existing tests. I backed that out and am now trying to understand how the code relates to the description in the torch-struct paper so that I can make the correct fix.
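
To make the dispatch problem concrete, a hypothetical sketch of the naive overload described above might look like the following. The wrapper name semiring_matmul and the multiply call on BandedMatrix are assumptions for illustration, not torch-struct's actual API or the fix that was eventually merged:

import torch

try:
    import genbmm
except ImportError:
    genbmm = None

def semiring_matmul(semiring, a, b):
    # Banded path: genbmm's BandedMatrix cannot go through the dense
    # matmul implementation, so route it to a banded product instead.
    if genbmm is not None and isinstance(a, genbmm.BandedMatrix):
        return a.multiply(b)
    # Dense path: ordinary tensors use the semiring's matmul as before.
    return semiring.matmul(a, b)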


         for b in range(lengths.shape[0]):
             point = (lengths[b] + M) // 2