@@ -82,10 +82,13 @@ def gaussian_kernel(
8282) -> torch .Tensor :
8383 r"""Returns the 1-dimensional Gaussian kernel of size \(K\).
8484
85- $$ G(x) = \frac{1}{\sum_{y = 1}^{K} G(y)} \exp
85+ $$ G(x) = \gamma \exp
8686 \left(-\frac{(x - \mu)^2}{2 \sigma^2}\right) $$
8787
88- where \(x \in [1; K]\) is a position in the kernel
88+ where \(\gamma\) is such that
89+
90+ $$ \sum_{x = 1}^{K} G(x) = 1 $$
91+
8992 and \(\mu = \frac{1 + K}{2}\).
9093
9194 Args:
@@ -263,124 +266,3 @@ def gradient_kernel(kernel: torch.Tensor) -> torch.Tensor:
263266 """
264267
265268 return torch .stack ([kernel , kernel .t ()]).unsqueeze (1 )
266-
267-
def tensor_norm(
    x: torch.Tensor,
    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
    keepdim: bool = False,
    norm: str = 'L2',
) -> torch.Tensor:
    r"""Returns the norm of \(x\).

    $$ L_1(x) = \left\| x \right\|_1 = \sum_i \left| x_i \right| $$

    $$ L_2(x) = \left\| x \right\|_2 = \left( \sum_i x^2_i \right)^\frac{1}{2} $$

    Args:
        x: A tensor, \((*,)\).
        dim: The dimension(s) along which to calculate the norm.
        keepdim: Whether the output tensor has `dim` retained or not.
        norm: Specifies the norm function to apply:
            `'L1'` | `'L2'` | `'L2_squared'`.

    Returns:
        The norm tensor, \((*,)\) with `dim` reduced (unless `keepdim`).

    Wikipedia:
        https://en.wikipedia.org/wiki/Norm_(mathematics)

    Example:
        >>> x = torch.arange(9).float().view(3, 3)
        >>> x
        tensor([[0., 1., 2.],
                [3., 4., 5.],
                [6., 7., 8.]])
        >>> tensor_norm(x, dim=0)
        tensor([6.7082, 8.1240, 9.6437])
    """

    if norm == 'L1':
        x = x.abs()
    else:  # norm in ['L2', 'L2_squared']
        # NOTE: any unrecognized `norm` string silently falls through to the
        # 'L2_squared' behavior (squared sum, no square root).
        x = x ** 2

    x = x.sum(dim=dim, keepdim=keepdim)

    if norm == 'L2':
        x = x.sqrt()

    return x
311-
312-
def normalize_tensor(
    x: torch.Tensor,
    dim: List[int],  # Union[int, Tuple[int, ...]] = ()
    norm: str = 'L2',
    epsilon: float = 1e-8,
) -> torch.Tensor:
    r"""Returns \(x\) normalized.

    $$ \hat{x} = \frac{x}{\left\|x\right\|} $$

    Args:
        x: A tensor, \((*,)\).
        dim: The dimension(s) along which to normalize.
        norm: Specifies the norm function to use:
            `'L1'` | `'L2'` | `'L2_squared'`.
        epsilon: A numerical stability term.

    Returns:
        The normalized tensor, \((*,)\).

    Example:
        >>> x = torch.arange(9, dtype=torch.float).view(3, 3)
        >>> x
        tensor([[0., 1., 2.],
                [3., 4., 5.],
                [6., 7., 8.]])
        >>> normalize_tensor(x, dim=0)
        tensor([[0.0000, 0.1231, 0.2074],
                [0.4472, 0.4924, 0.5185],
                [0.8944, 0.8616, 0.8296]])
    """

    # `keepdim=True` so the norm broadcasts against `x` in the division below.
    norm = tensor_norm(x, dim=dim, keepdim=True, norm=norm)

    return x / (norm + epsilon)
348-
349-
def unravel_index(
    indices: torch.LongTensor,
    shape: List[int],
) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    This is a `torch` implementation of `numpy.unravel_index`.

    Args:
        indices: A tensor of (flat) indices, \((*, N)\).
        shape: The targeted shape, \((D,)\).

    Returns:
        The unraveled coordinates, \((*, N, D)\).

    Example:
        >>> unravel_index(torch.arange(9), shape=(3, 3))
        tensor([[0, 0],
                [0, 1],
                [0, 2],
                [1, 0],
                [1, 1],
                [1, 2],
                [2, 0],
                [2, 1],
                [2, 2]])
    """

    coordinates = []
    remaining = indices

    # Peel off one dimension at a time, innermost first, prepending so the
    # final list is already in outermost-to-innermost order.
    for size in shape[::-1]:
        coordinates.insert(0, remaining % size)
        remaining = remaining // size

    return torch.stack(coordinates, dim=-1)
0 commit comments