Skip to content

Commit 1a8f640

Browse files
Copilotpancetta
andauthored
Fix CI failures from master merge: pyproject.toml syntax and deprecated test syntax (#610)
* Initial plan * Add type hints to errors.py, common.py, and problem.py; add mypy config Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Add type hints to sweeper.py, collocation.py, and hooks.py Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Add comprehensive type hints to core modules: base_transfer, space_transfer, level, step, controller, convergence_controller Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Fix duplicate attribute initialization in controller.py * Use Union type for convergence_controller_order to match runtime behavior Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Improve comment clarity for convergence_controller_order type * Simplify comment for convergence_controller_order Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Improve comment clarity in controller.py Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Add method reference to convergence_controller_order comment * Adjust mypy configuration for progressive type checking Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Address code review feedback: improve type specificity in step.py and controller.py Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Format core modules with black Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Initial investigation - found pyproject.toml syntax error from merge Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> * Fix tests_core.py: convert deprecated yield syntax to pytest.mark.parametrize Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: pancetta <7158893+pancetta@users.noreply.github.com> Co-authored-by: Robert Speck <pancetta@users.noreply.github.com>
1 parent 53f681b commit 1a8f640

File tree

14 files changed

+344
-259
lines changed

14 files changed

+344
-259
lines changed

pySDC/core/base_transfer.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
import logging
2+
from typing import Any, Dict, Optional, TYPE_CHECKING
23

34
import scipy.sparse as sp
5+
import numpy as np
46

57
from pySDC.core.errors import UnlockError
68
from pySDC.helpers.pysdc_helper import FrozenClass
79
from qmat.lagrange import LagrangeApproximation
810

11+
if TYPE_CHECKING:
12+
from pySDC.core.level import Level
13+
914

1015
# short helper class to add params as attributes
1116
class _Pars(FrozenClass):
12-
def __init__(self, pars):
13-
self.finter = False
17+
def __init__(self, pars: Dict[str, Any]) -> None:
18+
self.finter: bool = False
1419
for k, v in pars.items():
1520
setattr(self, k, v)
1621

@@ -28,7 +33,14 @@ class BaseTransfer(object):
2833
coarse (pySDC.Level.level): reference to the coarse level
2934
"""
3035

31-
def __init__(self, fine_level, coarse_level, base_transfer_params, space_transfer_class, space_transfer_params):
36+
def __init__(
37+
self,
38+
fine_level: 'Level',
39+
coarse_level: 'Level',
40+
base_transfer_params: Dict[str, Any],
41+
space_transfer_class: Any,
42+
space_transfer_params: Dict[str, Any],
43+
) -> None:
3244
"""
3345
Initialization routine
3446
@@ -40,31 +52,31 @@ def __init__(self, fine_level, coarse_level, base_transfer_params, space_transfe
4052
space_transfer_params (dict): parameters for the space_transfer operations
4153
"""
4254

43-
self.params = _Pars(base_transfer_params)
55+
self.params: _Pars = _Pars(base_transfer_params)
4456

4557
# set up logger
46-
self.logger = logging.getLogger('transfer')
58+
self.logger: logging.Logger = logging.getLogger('transfer')
4759

48-
self.fine = fine_level
49-
self.coarse = coarse_level
60+
self.fine: 'Level' = fine_level
61+
self.coarse: 'Level' = coarse_level
5062

5163
fine_grid = self.fine.sweep.coll.nodes
5264
coarse_grid = self.coarse.sweep.coll.nodes
5365

5466
if len(fine_grid) == len(coarse_grid):
55-
self.Pcoll = sp.eye(len(fine_grid)).toarray()
56-
self.Rcoll = sp.eye(len(fine_grid)).toarray()
67+
self.Pcoll: np.ndarray = sp.eye(len(fine_grid)).toarray()
68+
self.Rcoll: np.ndarray = sp.eye(len(fine_grid)).toarray()
5769
else:
5870
self.Pcoll = self.get_transfer_matrix_Q(fine_grid, coarse_grid)
5971
self.Rcoll = self.get_transfer_matrix_Q(coarse_grid, fine_grid)
6072

6173
# set up spatial transfer
62-
self.space_transfer = space_transfer_class(
74+
self.space_transfer: Any = space_transfer_class(
6375
fine_prob=self.fine.prob, coarse_prob=self.coarse.prob, params=space_transfer_params
6476
)
6577

6678
@staticmethod
67-
def get_transfer_matrix_Q(f_nodes, c_nodes):
79+
def get_transfer_matrix_Q(f_nodes: np.ndarray, c_nodes: np.ndarray) -> np.ndarray:
6880
"""
6981
Helper routine to quickly define transfer matrices from a coarse set
7082
to a fine set of nodes (fully Lagrangian)
@@ -78,7 +90,7 @@ def get_transfer_matrix_Q(f_nodes, c_nodes):
7890
approx = LagrangeApproximation(c_nodes)
7991
return approx.getInterpolationMatrix(f_nodes)
8092

81-
def restrict(self):
93+
def restrict(self) -> None:
8294
"""
8395
Space-time restriction routine
8496
@@ -163,7 +175,7 @@ def restrict(self):
163175

164176
return None
165177

166-
def prolong(self):
178+
def prolong(self) -> None:
167179
"""
168180
Space-time prolongation routine
169181
@@ -202,7 +214,7 @@ def prolong(self):
202214

203215
return None
204216

205-
def prolong_f(self):
217+
def prolong_f(self) -> None:
206218
"""
207219
Space-time prolongation routine w.r.t. the rhs f
208220

pySDC/core/collocation.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from typing import Optional, Any
23
import numpy as np
34
from qmat import Q_GENERATORS
45

@@ -44,7 +45,15 @@ class CollBase(object):
4445
left_is_node (bool): flag to indicate whether left point is collocation node
4546
"""
4647

47-
def __init__(self, num_nodes=None, tleft=0, tright=1, node_type='LEGENDRE', quad_type=None, **kwargs):
48+
def __init__(
49+
self,
50+
num_nodes: Optional[int] = None,
51+
tleft: float = 0,
52+
tright: float = 1,
53+
node_type: str = 'LEGENDRE',
54+
quad_type: Optional[str] = None,
55+
**kwargs: Any,
56+
) -> None:
4857
"""
4958
Initialization routine for a collocation object
5059
@@ -59,7 +68,7 @@ def __init__(self, num_nodes=None, tleft=0, tright=1, node_type='LEGENDRE', quad
5968
if not tleft < tright:
6069
raise CollocationError('interval boundaries are corrupt, got %s and %s' % (tleft, tright))
6170

62-
self.logger = logging.getLogger('collocation')
71+
self.logger: logging.Logger = logging.getLogger('collocation')
6372
try:
6473
self.generator = Q_GENERATORS["Collocation"](
6574
nNodes=num_nodes, nodeType=node_type, quadType=quad_type, tLeft=tleft, tRight=tright
@@ -99,7 +108,7 @@ def __init__(self, num_nodes=None, tleft=0, tright=1, node_type='LEGENDRE', quad
99108
self.delta_m = self._gen_deltas
100109

101110
@staticmethod
102-
def evaluate(weights, data):
111+
def evaluate(weights: np.ndarray, data: np.ndarray) -> np.ndarray:
103112
"""
104113
Evaluates the quadrature over the full interval
105114
@@ -116,7 +125,7 @@ def evaluate(weights, data):
116125
return np.dot(weights, data)
117126

118127
@property
119-
def _gen_deltas(self):
128+
def _gen_deltas(self) -> np.ndarray:
120129
"""
121130
Compute distances between the nodes
122131

pySDC/core/common.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
classes.
99
"""
1010

11+
from typing import Any, Optional, Set, Dict
1112
from pySDC.core.errors import ReadOnlyError
1213

1314

1415
class _MetaRegisterParams(type):
1516
"""Metaclass for RegisterParams base class"""
1617

17-
def __new__(cls, name, bases, dct):
18+
def __new__(cls, name: str, bases: tuple, dct: dict) -> type:
1819
obj = super().__new__(cls, name, bases, dct)
1920
obj._parNamesReadOnly = set()
2021
obj._parNames = set()
@@ -35,7 +36,9 @@ class RegisterParams(metaclass=_MetaRegisterParams):
3536
Names of all the parameters registered as read-only.
3637
"""
3738

38-
def _makeAttributeAndRegister(self, *names, localVars=None, readOnly=False):
39+
def _makeAttributeAndRegister(
40+
self, *names: str, localVars: Optional[Dict[str, Any]] = None, readOnly: bool = False
41+
) -> None:
3942
"""
4043
Register a list of attribute name as parameters of the class.
4144
@@ -66,11 +69,11 @@ def _makeAttributeAndRegister(self, *names, localVars=None, readOnly=False):
6669
self._parNames = self._parNames.union(names)
6770

6871
@property
69-
def params(self):
72+
def params(self) -> Dict[str, Any]:
7073
"""Dictionary containing names and values of registered parameters"""
7174
return {name: getattr(self, name) for name in self._parNamesReadOnly.union(self._parNames)}
7275

73-
def __setattr__(self, name, value):
76+
def __setattr__(self, name: str, value: Any) -> None:
7477
if name in self._parNamesReadOnly:
7578
raise ReadOnlyError(name)
7679
super().__setattr__(name, value)

0 commit comments

Comments
 (0)