Skip to content

Commit 75ef97a

Browse files
authored
style: Ignore some type errors (#370)
* Add comment to ignore mypy errors in _nash_mtl.py * Add comment to ignore type error in mgda * Add comment to ignore type error in TensorDict immutability tricks, and add a comment
1 parent 9182b4a commit 75ef97a

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/torchjd/_autojac/_transform/_tensor_dict.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,18 @@ def _check_key_value_pair(key: Tensor, value: Tensor) -> None:
3232
# Make TensorDict immutable, following answer in
3333
# https://stackoverflow.com/questions/11014262/how-to-create-an-immutable-dictionary-in-python
3434
# coming from https://peps.python.org/pep-0351/
35+
# Note that this is not a perfect solution, because it breaks Liskov Substitution Principle, but
36+
# it works.
3537
def _raise_immutable_error(self, *args, **kwargs) -> None:
3638
raise TypeError(f"{self.__class__.__name__} is immutable.")
3739

3840
__setitem__ = _raise_immutable_error
3941
__delitem__ = _raise_immutable_error
4042
clear = _raise_immutable_error
41-
update = _raise_immutable_error
42-
setdefault = _raise_immutable_error
43-
pop = _raise_immutable_error
44-
popitem = _raise_immutable_error
43+
update = _raise_immutable_error # type: ignore[assignment]
44+
setdefault = _raise_immutable_error # type: ignore[assignment]
45+
pop = _raise_immutable_error # type: ignore[assignment]
46+
popitem = _raise_immutable_error # type: ignore[assignment]
4547

4648

4749
class Gradients(TensorDict):

src/torchjd/aggregation/_mgda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def forward(self, gramian: Tensor) -> Tensor:
7777
elif b <= a:
7878
gamma = 0.0
7979
else:
80-
gamma = (b - a) / (b + c - 2 * a)
80+
gamma = (b - a) / (b + c - 2 * a) # type: ignore[assignment]
8181
alpha = (1 - gamma) * alpha + gamma * e_t
8282
if gamma < self.epsilon:
8383
break

src/torchjd/aggregation/_nash_mtl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2424
# SOFTWARE.
2525

26+
# mypy: ignore-errors
27+
2628
from ._utils.check_dependencies import check_dependencies_are_installed
2729
from ._weighting_bases import Matrix, Weighting
2830

0 commit comments

Comments
 (0)