@@ -145,7 +145,7 @@ def describe_conf(cls):
145145
146146 @classmethod
147147 def _wrap_pytorch_functions (cls ):
148- from torch import sum , repeat_interleave , flip
148+ from torch import sum , repeat_interleave , isclose
149149 import torch
150150
151151 def _sum (tensor , axis = None , dtype = None , keepdims = False ):
@@ -265,6 +265,14 @@ def _to_numpy(tensor):
265265 # Not a torch tensor, return as-is
266266 return tensor
267267
268+ def _fill_diagonal (tensor , value ):
269+ """Fill the diagonal of a 2D tensor with the given value"""
270+ if tensor .dim () != 2 :
271+ raise ValueError ("fill_diagonal only supports 2D tensors" )
272+ diagonal_indices = torch .arange (min (tensor .size (0 ), tensor .size (1 )))
273+ tensor [diagonal_indices , diagonal_indices ] = value
274+ return tensor
275+
268276 cls .tfnp .sum = _sum
269277 cls .tfnp .repeat = _repeat
270278 cls .tfnp .expand_dims = lambda tensor , axis : tensor
@@ -285,6 +293,14 @@ def _to_numpy(tensor):
285293 cls .tfnp .tile = lambda tensor , repeats : tensor .repeat (repeats )
286294 cls .tfnp .ravel = lambda tensor : tensor .flatten ()
287295 cls .tfnp .packbits = _packbits
296+ cls .tfnp .fill_diagonal = _fill_diagonal
297+ cls .tfnp .isclose = lambda a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False : isclose (
298+ a ,
299+ torch .tensor (b , dtype = a .dtype , device = a .device ),
300+ rtol = rtol ,
301+ atol = atol ,
302+ equal_nan = equal_nan
303+ )
288304
289305 @classmethod
290306 def _wrap_pykeops_functions (cls ):
0 commit comments