Skip to content

Commit e915cc0

Browse files
authored
Merge pull request #70 from Turakar/smallopt_repeat
Add short path for LinearOperator.repeat()
2 parents 6b6d19a + 4dbcc93 commit e915cc0

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

linear_operator/operators/_linear_operator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,10 @@ def repeat(self, *sizes: Union[int, Tuple[int, ...]]) -> LinearOperator:
20232023
"""
20242024
from .batch_repeat_linear_operator import BatchRepeatLinearOperator
20252025

2026+
# Short path if no repetition is necessary
2027+
if all(x == 1 for x in sizes) and len(sizes) == self.dim():
2028+
return self
2029+
20262030
if len(sizes) < 3 or tuple(sizes[-2:]) != (1, 1):
20272031
raise RuntimeError(
20282032
"Invalid repeat arguments {}. Currently, repeat only works to create repeated "

test/utils/test_repeat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import unittest
2+
3+
import torch
4+
5+
from linear_operator.operators.dense_linear_operator import DenseLinearOperator
6+
from linear_operator.test.utils import approx_equal
7+
8+
9+
class TestRepeat(unittest.TestCase):
10+
def make_example(self):
11+
return DenseLinearOperator(torch.randn(3, 3))
12+
13+
def test_repeat(self):
14+
example = self.make_example()
15+
repeated = example.repeat(2, 1, 1)
16+
repeated_dense = example.to_dense().repeat(2, 1, 1)
17+
self.assertTrue(approx_equal(repeated.to_dense(), repeated_dense))
18+
19+
def test_repeat_noop(self):
20+
example = self.make_example()
21+
repeated = example.repeat(1, 1)
22+
self.assertTrue(approx_equal(repeated.to_dense(), example.to_dense()))
23+
self.assertIsInstance(repeated, DenseLinearOperator) # ensure that fast path is taken

0 commit comments

Comments
 (0)