Skip to content

Commit c6925dd

Browse files
committed
[BUG]
1 parent dc0127f commit c6925dd

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

gempy_engine/core/backend_tensor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def describe_conf(cls):
145145

146146
@classmethod
147147
def _wrap_pytorch_functions(cls):
148-
from torch import sum, repeat_interleave, flip
148+
from torch import sum, repeat_interleave, isclose
149149
import torch
150150

151151
def _sum(tensor, axis=None, dtype=None, keepdims=False):
@@ -265,6 +265,14 @@ def _to_numpy(tensor):
265265
# Not a torch tensor, return as-is
266266
return tensor
267267

268+
def _fill_diagonal(tensor, value):
269+
"""Fill the diagonal of a 2D tensor with the given value"""
270+
if tensor.dim() != 2:
271+
raise ValueError("fill_diagonal only supports 2D tensors")
272+
diagonal_indices = torch.arange(min(tensor.size(0), tensor.size(1)))
273+
tensor[diagonal_indices, diagonal_indices] = value
274+
return tensor
275+
268276
cls.tfnp.sum = _sum
269277
cls.tfnp.repeat = _repeat
270278
cls.tfnp.expand_dims = lambda tensor, axis: tensor
@@ -285,6 +293,14 @@ def _to_numpy(tensor):
285293
cls.tfnp.tile = lambda tensor, repeats: tensor.repeat(repeats)
286294
cls.tfnp.ravel = lambda tensor: tensor.flatten()
287295
cls.tfnp.packbits = _packbits
296+
cls.tfnp.fill_diagonal = _fill_diagonal
297+
cls.tfnp.isclose = lambda a, b, rtol=1e-05, atol=1e-08, equal_nan=False: isclose(
298+
a,
299+
torch.tensor(b, dtype=a.dtype, device=a.device),
300+
rtol=rtol,
301+
atol=atol,
302+
equal_nan=equal_nan
303+
)
288304

289305
@classmethod
290306
def _wrap_pykeops_functions(cls):

0 commit comments

Comments
 (0)