Skip to content

Commit 616ed76

Browse files
committed
Apply fixes to pass ruff tests
1 parent 8280654 commit 616ed76

File tree

11 files changed

+14
-32
lines changed

11 files changed

+14
-32
lines changed

cebra/attribution/attribution_models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
# Code will be open-sourced upon publication.
66
#
77
import dataclasses
8-
import sys
98
import time
109

10+
import cvxpy as cp
1111
import numpy as np
1212
import scipy.linalg
1313
import sklearn.metrics
1414
import torch
1515
import torch.nn as nn
16+
import tqdm
1617
from captum.attr import NeuronFeatureAblation
1718
from captum.attr import NeuronGradient
1819
from captum.attr import NeuronGradientShap
@@ -172,10 +173,11 @@ def _inverse_lsq_cvxpy(matrix: np.ndarray,
172173
matrix_param = cp.Parameter((matrix.shape[0], matrix.shape[1]))
173174
matrix_param.value = matrix
174175

175-
I = np.eye(matrix.shape[0])
176+
identity = np.eye(matrix.shape[0])
176177
matrix_inverse = cp.Variable((matrix.shape[1], matrix.shape[0]))
177178

178-
objective = cp.Minimize(cp.norm(matrix @ matrix_inverse - I, "fro"))
179+
objective = cp.Minimize(
180+
cp.norm(matrix @ matrix_inverse - identity, "fro"))
179181
prob = cp.Problem(objective)
180182
prob.solve(verbose=False, solver=solver)
181183

cebra/data/multiobjective.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# Not licensed yet. Distribution for review.
55
# Code will be open-sourced upon publication.
66
#
7-
from typing import List
87

98
import literate_dataclasses as dataclasses
109

cebra/models/jacobian_regularizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import division
1414

1515
import torch
16-
import torch.autograd as autograd
1716
import torch.nn as nn
1817

1918

cebra/models/multi_criterions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Not licensed yet. Distribution for review.
55
# Code will be open-sourced upon publication.
66
#
7-
from typing import Optional, Tuple, Union
7+
from typing import Tuple
88

99
import torch
1010
from torch import nn

cebra/solver/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
import logging
3535
import os
3636
import time
37-
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
37+
from typing import Callable, Dict, List, Literal, Optional
3838

3939
import literate_dataclasses as dataclasses
40-
import numpy as np
4140
import torch
4241

4342
import cebra

cebra/solver/multiobjective.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,22 @@
66
#
77
"""Multiobjective contrastive learning."""
88

9-
import abc
109
import logging
11-
import os
1210
import time
1311
import warnings
14-
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
12+
from typing import Callable, Dict, List, Optional, Tuple
1513

1614
import literate_dataclasses as dataclasses
1715
import numpy as np
1816
import torch
19-
import tqdm
2017

2118
import cebra
2219
import cebra.data
2320
import cebra.io
2421
import cebra.models
25-
import cebra.solver.base as abc_
2622
from cebra.solver import register
2723
from cebra.solver.base import Solver
24+
from cebra.solver.schedulers import Scheduler
2825
from cebra.solver.util import Meter
2926

3027

@@ -43,9 +40,8 @@ def __init__(self, loader):
4340
def _check_overwriting_key(self, key):
4441
if key in self.current_info:
4542
warnings.warn(
46-
f"Configuration key already exists. Overwriting existing value. "
47-
f"If you don't want to overwrite you should call push() before."
48-
)
43+
"Configuration key already exists. Overwriting existing value. "
44+
"If you don't want to overwrite you should call push() before.")
4945

5046
def _check_pushed_status(self):
5147
if "slice" not in self.current_info:

cebra/solver/regularized.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
#
77
"""Regularized contrastive learning."""
88

9-
from typing import Callable, Dict, List, Literal, Optional, Union
9+
from typing import Dict, Optional
1010

1111
import literate_dataclasses as dataclasses
1212
import torch
1313

1414
import cebra
1515
import cebra.data
1616
import cebra.models
17-
import cebra.solver.multiobjective as abc_
1817
from cebra.solver import register
1918
from cebra.solver.single_session import SingleSessionSolver
2019

cebra/solver/single_session.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
"""Single session solvers embed a single pair of time series."""
2323

2424
import copy
25-
import os
26-
from collections.abc import Iterable
27-
from typing import Callable, Dict, List, Literal, Optional, Union
25+
from typing import Dict
2826

2927
import literate_dataclasses as dataclasses
3028
import torch

examples/train_and_evaluate.ipynb

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
"import cebra.models\n",
2121
"import cebra.solver\n",
2222
"\n",
23-
"import numpy as np\n",
2423
"import matplotlib.pyplot as plt\n",
2524
"import torch"
2625
]
@@ -50,7 +49,6 @@
5049
}
5150
],
5251
"source": [
53-
"import cebra.solver\n",
5452
"\n",
5553
"cebra.solver.get_options()"
5654
]
@@ -386,7 +384,6 @@
386384
}
387385
],
388386
"source": [
389-
"import numpy as np\n",
390387
"from sklearn.linear_model import LinearRegression\n",
391388
"\n",
392389
"# Create and fit the linear regression model\n",
@@ -534,7 +531,6 @@
534531
"metadata": {},
535532
"outputs": [],
536533
"source": [
537-
"import numpy as np\n",
538534
"from sklearn.linear_model import LinearRegression\n",
539535
"\n",
540536
"R2_dict = {}\n",

tests/test_models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import pytest
2525
import torch
26-
from torch import nn
2726

2827
import cebra.models
2928
import cebra.models.model

0 commit comments

Comments
 (0)