-
Notifications
You must be signed in to change notification settings - Fork 11
Givens orthogonal layer #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
I somehow broke @kevinchern's tests, what the hell... |
| def test_store_config(self): | ||
| with self.subTest("Simple case"): | ||
|
|
||
| class MyModel(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove formatting changes. Is this "black" formatting?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I have it by default on my vscode
@VolodyaCO which tests? I'm seeing |
I forgot to update my tests to float64 precision. Now that I've done it, it's weird that all of the current failing tests are failing on File "/Users/distiller/project/tests/test_nn.py", line 144, in test_LinearBlock
self.assertTrue(model_probably_good(model, (din,), (dout,))) |
Ahhhhhh. OK Theo also flagged this at #50 . It's a poorly-written test.. you can ignore it. |
| Returns: | ||
| list[list[tuple[int, int]]]: Blocks of edges for parallel Givens rotations. | ||
| Note: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better as a note directive: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html#directive-note
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where should I put this? in the release notes? or in the docstring itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simply change the Note: to
.. note::
Lorem ipsum...
which would render a note box if we generate docs with Sphinx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have done this now
| angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations | ||
| between pairs of dimensions. | ||
| blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that | ||
| specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2 | ||
| pairs of independent rotations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code formatting?
| angles (torch.Tensor): A ((n - 1) * n // 2,) shaped tensor containing all rotations | |
| between pairs of dimensions. | |
| blocks (torch.Tensor): A (n-1, n//2, 2) shaped tensor containing the indices that | |
| specify rotations between pairs of dimensions. Each of the n-1 blocks contains n//2 | |
| pairs of independent rotations. | |
| angles (torch.Tensor): A ``((n - 1) * n // 2,)`` shaped tensor containing all rotations | |
| between pairs of dimensions. | |
| blocks (torch.Tensor): A ``(n - 1, n // 2, 2)`` shaped tensor containing the indices that | |
| specify rotations between pairs of dimensions. Each of the n-1 blocks contains n // 2 | |
| pairs of independent rotations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks.
| """ | ||
| # Blocks is of shape (n_blocks, n/2, 2) containing indices for angles | ||
| # Within each block, each Givens rotation is commuting, so we can apply them in parallel | ||
| U = torch.eye(n, device=angles.device, dtype=angles.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slight preference to keep variables lower-case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed this in the main GivensRotationLayer class. In the other code, I kept the capital letters just so that if someone is reading the algorithm in the paper alongside the code, each part of the algorithm is more easily understood.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also favour variable names in lower case with upper case reserved for constants, primarily because it is a widely adopted convention. I agree having a 1-1 correspondence between paper notation and implementation is important for readability, but I think making exceptions paper-by-paper can become messy. I suggest noting the correspondence between variable names and paper notation in the docstring.
--
I think I snuck in some upper case variable names in the codebase... should track those down at some point 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Playing devils' advocate against myself here: sometimes descriptive variable names are unnecessarily verbose and unhelpful in describing the algorithm.
🤷♀️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that for some algorithms, readers are more used to certain notations, for example, that U is orthogonal (actually orthonormal). I would make a vote for picking a convention 😆
| angles, blocks, Ufwd_saved = ctx.saved_tensors | ||
| Ufwd = Ufwd_saved.clone() | ||
| M = grad_output.t() # dL/dU, i.e., grad_output is of shape (n, n) | ||
| n = M.size(1) | ||
| block_size = n // 2 | ||
| A = torch.zeros((block_size, n), device=angles.device, dtype=angles.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here re lowercase for Ufwd, M, and A. Avoids incorrect colour highlighting in themes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm, I didn't read this about the incorrect colour highlighting before I made my previous comment. I still think that it is easier to read the algorithm alongside the code if the use of lower/upper case match. For example, lower case m is usually used for an integer variable, not a tensor.
| return U | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing return type hint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added the type hint as well as a longer explanation on what this return is.
| U = self._create_rotation_matrix() | ||
| rotated_x = einsum(x, U, "... i, o i -> ... o") | ||
| if self.bias is not None: | ||
| rotated_x = rotated_x + self.bias |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| rotated_x = rotated_x + self.bias | |
| rotated_x += self.bias |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
| from einops import einsum | ||
|
|
||
|
|
||
| class NaiveGivensRotationLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not very keen on having a full on separate implementation here just to compare with/test the GivensRotationLayer. If this NaiveGivensRotationLayer is useful, should it be part of the package instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We discussed this in our one on one but, just for the record, there is no difference between the NaiveGivensRotationLayer and the GivensRotationLayer in the forward or backward passes. The naïve implementation is there to make sure that the forward and backward passes indeed match. The GivensRotationLayer should always be used because it has a substantially better runtime complexity. Thus, the naïve implementation is not useful—other than for a sanity check.
tests/test_nn.py
Outdated
| @parameterized.expand([(n, bias) for n in [4, 5, 6, 9, 10] for bias in [True, False]]) | ||
| def test_forward_agreement(self, n, bias): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests do seem a bit too.. complex. Better to try and test more minimal aspects of the class, if possible. I'd much rather have separate integration-like tests that can assert that model behave as expected, while having these be strictly, small scale, isolated unit tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added some tests to test invalid inputs too. These forward and backward tests are for testing that the correct input/output is given when compared to the naïve implementation. The model_probably_good test is done as unit test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added other unit tests where I test incorrect inputs as well. In ML models, the forward and backward passes should be what one expects them to be, and this module gives the opportunity to test this correctly. I do agree that we should separate other tests that (at least) I wrote, which have to do with training a model to see if the intended final trained state is what is expected. However, the tests I present in this PR are not the result of training but explicit comparisons with the naïve approach; I don't know if we could regard those as integration tests.
f18b476 to
46062e0
Compare
|
After a bit of git wrangling, I was able to clean my whole mess of merge commits 😆. |
anahitamansouri
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a nice PR Vlad. It took me a while to go over the paper and this PR :) The only thing is the tests that are failing. Thanks for the great work.
| self.n = n | ||
| self.n_angles = n * (n - 1) // 2 | ||
| self.angles = nn.Parameter(torch.randn(self.n_angles)) | ||
| blocks_edges = _get_blocks_edges(n) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could directly return torch.LongTensor from get_blocks_edges to avoid the conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I set _get_blocks_edges to a private function, so it shouldn't make a difference if I convert the list to a tensor in the orthogonal module or in the function itself.
46062e0 to
7f76571
Compare
kevinchern
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did a quick pass to provide some feedback before taking some time to take a deep dive into the paper.
|
|
||
|
|
||
| def _get_blocks_edges(n: int) -> list[list[tuple[int, int]]]: | ||
| """Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| """Uses the circle method for Round Robin pairing to create blocks of edges for parallel Givens | |
| """Uses the circle method for round-robin pairing to create blocks of edges for parallel Givens |
(and other occurrences)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should _get_blocks_edges should be a method in GivensRotation instead? The orthogonal module is general while this function is a helper function bespoke to GivensRotation.
cc @thisac
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe... though what would be the attribute of GivensRotation used in _get_blocks_edges? n only?
| return grad_theta, None, None | ||
|
|
||
|
|
||
| class GivensRotationLayer(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rename to GivensRotation (parallel to nn.Linear)
| if n % 2 != 0: | ||
| n += 1 # Add a dummy dimension for odd n | ||
| is_odd = True | ||
| else: | ||
| is_odd = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be cleaner like this 😛
| if n % 2 != 0: | |
| n += 1 # Add a dummy dimension for odd n | |
| is_odd = True | |
| else: | |
| is_odd = False | |
| odd = n % 2 != 0 | |
| if odd: | |
| n += 1 |
or
odd = n % 2 ! = 0
n += odd
but this is less obvious.. (edit: not a big fan of n+=odd notation 😆)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is cleaner! (the first suggestion, not the n+=odd 😆 )
| ignored. | ||
| """ | ||
| if n % 2 != 0: | ||
| n += 1 # Add a dummy dimension for odd n |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rule-of-thumb for comments: explain the "why" or motivation as opposed to "what" (which is clear in this context)
| for _ in range(n - 1): | ||
| pairs = circle_method(sequence) | ||
| if is_odd: | ||
| # Remove pairs involving the dummy dimension: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # Remove pairs involving the dummy dimension: | |
| # Remove pairs involving the dummy dimension |
|
|
||
| @staticmethod | ||
| def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None]: | ||
| """Computes the VJP needed for backward propagation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| """Computes the VJP needed for backward propagation. | |
| """Computes the vector-Jacobian product needed for backward propagation. |
| idx_block = torch.arange(block_size, device=angles.device) | ||
| for b, block in enumerate(blocks): | ||
| # angles is of shape (n_angles,) containing all angles for contiguous blocks. | ||
| angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| angles_in_block = angles[idx_block + b * blocks.size(1)] # shape (n/2,) | |
| angles_in_block = angles[idx_block + b * block_size] # shape (n/2,) |
If I understand correctly, blocks.size(1) will be block_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes, while writing the algorithm I though blocks could have different sizes if n is odd, but that is not true. All blocks will have the same block size.
| c = torch.cos(angles_in_block) | ||
| s = torch.sin(angles_in_block) | ||
| i_idx = block[:, 0] | ||
| j_idx = block[:, 1] | ||
| r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx] | ||
| r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unsqueeze once in the beginning
| c = torch.cos(angles_in_block) | |
| s = torch.sin(angles_in_block) | |
| i_idx = block[:, 0] | |
| j_idx = block[:, 1] | |
| r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx] | |
| r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx] | |
| c = torch.cos(angles_in_block).unsqueeze(0) | |
| s = torch.sin(angles_in_block).unsqueeze(0) | |
| i_idx = block[:, 0] | |
| j_idx = block[:, 1] | |
| r_i = c * U[:, i_idx] + s * U[:, j_idx] | |
| r_j = -s * U[:, i_idx] + c * U[:, j_idx] |
| U = torch.eye(n, device=angles.device, dtype=angles.dtype) | ||
| block_size = n // 2 | ||
| idx_block = torch.arange(block_size, device=angles.device) | ||
| for b, block in enumerate(blocks): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we commit to using paper variable names here, we should be consistent and use, e.g., B instead of blocks.
If that's the case, I'd prefer to be a little more wasteful and have B = blocks to keep the input argument blocks instead of B. This inconsistency makes me lean towards named variables more (with a look-up table in the docstring).
| r_i = c.unsqueeze(0) * U[:, i_idx] + s.unsqueeze(0) * U[:, j_idx] | ||
| r_j = -s.unsqueeze(0) * U[:, i_idx] + c.unsqueeze(0) * U[:, j_idx] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are r_i and r_j are backwards?
I think it should be:
-
$\cos - \sin$ for i, and -
$\sin + \cos$ for j.
Not sure if this has a significant impact on validity of method. If it does, then tests should be revised first to see why this error was not detected
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes... well... in the paper the rotation matrices were written the other way around, I think. I did the math separately and this way everything is consistent.
This PR adds an orthogonal layer given by Givens rotations, using the parallel algorithm described by Firas in https://arxiv.org/abs/2106.00003, which gives a forward complexity of O(n) and backward complexity of O(n log(n)), even though there are O(n^2) rotations.
This PR still is in draft. I wrote it for even n. Probably some more unit tests are to be done, but I am quite lazy (will do it after all math is checked for odd n).