Skip to content

Commit 7b3b661

Browse files
authored
Fix for GPU runtime error for semiring (#124)
Fix #121
1 parent f4e374e commit 7b3b661

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

torch_struct/semirings/semirings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def dot(cls, a, b):
4949

5050
@staticmethod
5151
def fill(c, mask, v):
52+
mask = mask.to(c.device)
5253
return torch.where(
5354
mask, v.type_as(c).view((-1,) + (1,) * (len(c.shape) - 1)), c
5455
)

0 commit comments

Comments
 (0)