We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f4e374e commit 7b3b661Copy full SHA for 7b3b661
torch_struct/semirings/semirings.py
@@ -49,6 +49,7 @@ def dot(cls, a, b):
49
50
@staticmethod
51
def fill(c, mask, v):
52
+ mask = mask.to(c.device)
53
return torch.where(
54
mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c
55
)
0 commit comments