Skip to content

Commit a047c55

Browse files
update pylint, isort and black versions in format.yml (#8506)
Time to bump the linting and formatting eco-system again. --------- Co-authored-by: Yushao Chen (Jerry) <[email protected]>
1 parent e0d0846 commit a047c55

35 files changed

+121
-131
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
:orphan:
2-
31
# Release 0.44.0-dev (development release)
42

53
<h3>New features since last release</h3>
@@ -148,6 +146,9 @@
148146

149147
<h3>Internal changes ⚙️</h3>
150148

149+
* Update versions for `pylint`, `isort` and `black` in `format.yml`
150+
[(#8506)](https://github.com/PennyLaneAI/pennylane/pull/8506)
151+
151152
* Reclassifies `registers` as a tertiary module for use with tach.
152153
[(#8513)](https://github.com/PennyLaneAI/pennylane/pull/8513)
153154

pennylane/_grad.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _get_grad_prim():
5050
grad_prim.prim_type = "higher_order"
5151

5252
@grad_prim.def_impl
53-
def _(*args, argnums, jaxpr, n_consts, method, h):
53+
def _grad_def_impl(*args, argnums, jaxpr, n_consts, method, h):
5454
if method or h: # pragma: no cover
5555
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
5656
consts = args[:n_consts]
@@ -63,7 +63,7 @@ def func(*inner_args):
6363

6464
# pylint: disable=unused-argument
6565
@grad_prim.def_abstract_eval
66-
def _(*args, argnums, jaxpr, n_consts, method, h):
66+
def _grad_abstract_eval(*args, argnums, jaxpr, n_consts, method, h):
6767
if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
6868
raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
6969
return tuple(args[i + n_consts] for i in argnums)
@@ -90,7 +90,7 @@ def _get_jacobian_prim():
9090
jacobian_prim.prim_type = "higher_order"
9191

9292
@jacobian_prim.def_impl
93-
def _(*args, argnums, jaxpr, n_consts, method, h):
93+
def _jacobian_def_impl(*args, argnums, jaxpr, n_consts, method, h):
9494
if method or h: # pragma: no cover
9595
raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
9696
consts = args[:n_consts]
@@ -103,7 +103,7 @@ def func(*inner_args):
103103

104104
# pylint: disable=unused-argument
105105
@jacobian_prim.def_abstract_eval
106-
def _(*args, argnums, jaxpr, n_consts, method, h):
106+
def _jacobian_abstract_eval(*args, argnums, jaxpr, n_consts, method, h):
107107
in_avals = tuple(args[i + n_consts] for i in argnums)
108108
out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars)
109109
return [

pennylane/allocation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,28 @@ class AllocateState(StrEnum):
4949
allocate_prim.multiple_results = True
5050

5151
@allocate_prim.def_impl
52-
def _(*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False):
52+
def _allocate_primitive_impl(
53+
*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False
54+
):
5355
raise NotImplementedError("jaxpr containing qubit allocation cannot be executed.")
5456

5557
# pylint: disable=unused-argument
5658
@allocate_prim.def_abstract_eval
57-
def _(*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False):
59+
def _allocate_primitive_abstract_eval(
60+
*, num_wires, state: AllocateState = AllocateState.ZERO, restored=False
61+
):
5862
return [jax.core.ShapedArray((), dtype=int) for _ in range(num_wires)]
5963

6064
deallocate_prim = QmlPrimitive("deallocate")
6165
deallocate_prim.multiple_results = True
6266

6367
@deallocate_prim.def_impl
64-
def _(*wires):
68+
def _deallocate_primitive_impl(*wires):
6569
raise NotImplementedError("jaxpr containing qubit deallocation cannot be executed.")
6670

6771
# pylint: disable=unused-argument
6872
@deallocate_prim.def_abstract_eval
69-
def _(*wires):
73+
def _deallocate_primitive_abstract_eval(*wires):
7074
return []
7175

7276

pennylane/capture/base_interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def _(self, x, *dyn_shape, shape, broadcast_dimensions, sharding):
442442

443443
# pylint: disable=unused-argument
444444
@PlxprInterpreter.register_primitive(jax.lax.iota_p)
445-
def _(self, *dyn_shape, dimension, dtype, shape, sharding):
445+
def _iota_primitive(self, *dyn_shape, dimension, dtype, shape, sharding):
446446
"""Handle the iota primitive created by jnp.arange
447447
448448
>>> import jax
@@ -646,7 +646,7 @@ class FlattenedInterpreter(PlxprInterpreter):
646646

647647

648648
@FlattenedInterpreter.register_primitive(pjit_p)
649-
def _(self, *invals, jaxpr, **params):
649+
def _pjit_primitive(self, *invals, jaxpr, **params):
650650
if jax.config.jax_dynamic_shapes:
651651
# just evaluate it so it doesn't throw dynamic shape errors
652652
return copy(self).eval(jaxpr.jaxpr, jaxpr.consts, *invals)

pennylane/control_flow/for_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,9 @@ def _get_for_loop_qfunc_prim():
280280

281281
# pylint: disable=too-many-arguments
282282
@for_loop_prim.def_impl
283-
def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice):
283+
def _impl(
284+
start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstract_shapes_slice
285+
):
284286

285287
consts = args[consts_slice]
286288
init_state = args[args_slice]
@@ -296,7 +298,7 @@ def _(start, stop, step, *args, jaxpr_body_fn, consts_slice, args_slice, abstrac
296298

297299
# pylint: disable=unused-argument
298300
@for_loop_prim.def_abstract_eval
299-
def _(start, stop, step, *args, args_slice, abstract_shapes_slice, **_):
301+
def _abstract_eval(start, stop, step, *args, args_slice, abstract_shapes_slice, **_):
300302
return args[abstract_shapes_slice] + args[args_slice]
301303

302304
return for_loop_prim

pennylane/control_flow/while_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _get_while_loop_qfunc_prim():
233233
register_custom_staging_rule(while_loop_prim, lambda params: params["jaxpr_body_fn"].outvars)
234234

235235
@while_loop_prim.def_impl
236-
def _(
236+
def _impl(
237237
*args,
238238
jaxpr_body_fn,
239239
jaxpr_cond_fn,
@@ -253,7 +253,7 @@ def _(
253253
return fn_res
254254

255255
@while_loop_prim.def_abstract_eval
256-
def _(*args, args_slice, **__):
256+
def _abstract_eval(*args, args_slice, **__):
257257
return args[args_slice]
258258

259259
return while_loop_prim

pennylane/data/base/typing_util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def get_type_str(cls: type | str | None) -> str: # pylint: disable=too-many-ret
7777
7878
Otherwise, returns the fully-qualified class name, including the module.
7979
"""
80+
# pylint: disable=unidiomatic-typecheck
81+
# Keep this check as it ensures that get_type_str(type(None)) = 'None'
82+
# rather than `NoneType`.
8083
if cls is None or cls is type(None):
8184
return "None"
8285

pennylane/decomposition/collect_resource_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def interpret_operation(self, op):
3535

3636

3737
@CollectResourceOps.register_primitive(adjoint_transform_prim)
38-
def _(self, *invals, jaxpr, lazy, n_consts): # pylint: disable=unused-argument
38+
def _adjoint_transform_prim(
39+
self, *invals, jaxpr, lazy, n_consts
40+
): # pylint: disable=unused-argument
3941
"""Collect all operations in the base plxpr and create adjoint resource ops with them."""
4042
consts = invals[:n_consts]
4143
args = invals[n_consts:]
@@ -47,7 +49,7 @@ def _(self, *invals, jaxpr, lazy, n_consts): # pylint: disable=unused-argument
4749

4850

4951
@CollectResourceOps.register_primitive(ctrl_transform_prim)
50-
def _(self, *invals, n_control, jaxpr, n_consts, **params):
52+
def _ctrl_transform_prim(self, *invals, n_control, jaxpr, n_consts, **params):
5153
"""Collect all operations in the target plxpr and create controlled resource ops with them."""
5254

5355
consts = invals[:n_consts]

pennylane/devices/default_gaussian.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -517,33 +517,27 @@ def photon_number(cov, mu, params, hbar=2.0):
517517
return ex, var
518518

519519

520-
def homodyne(phi=None):
520+
def homodyne(phi: float | None = None):
521521
"""Function factory that returns the Homodyne expectation of a one mode state.
522522
523523
Args:
524-
phi (float): the default phase space axis to perform the Homodyne measurement
524+
phi (Optional[float]): the default phase space axis to perform the Homodyne measurement
525525
526526
Returns:
527527
function: A function that accepts a single mode means vector, covariance matrix,
528528
and phase space angle phi, and returns the quadrature expectation
529529
value and variance.
530530
"""
531-
if phi is not None:
532531

533-
def _homodyne(cov, mu, params, hbar=2.0):
534-
"""Arbitrary angle homodyne expectation."""
535-
# pylint: disable=unused-argument
536-
rot = rotation(phi)
537-
muphi = rot.T @ mu
538-
covphi = rot.T @ cov @ rot
539-
return muphi[0], covphi[0, 0]
532+
# pylint: disable=unused-argument
533+
def _homodyne(cov, mu, params, hbar=2.0):
534+
"""Calculates the arbitrary angle homodyne expectation."""
540535

541-
return _homodyne
536+
# Use the fixed outer `phi` if it was provided,
537+
# otherwise use the dynamic `phi` from the parameters.
538+
measurement_phi = phi if phi is not None else params[0]
542539

543-
def _homodyne(cov, mu, params, hbar=2.0):
544-
"""Arbitrary angle homodyne expectation."""
545-
# pylint: disable=unused-argument
546-
rot = rotation(params[0])
540+
rot = rotation(measurement_phi)
547541
muphi = rot.T @ mu
548542
covphi = rot.T @ cov @ rot
549543
return muphi[0], covphi[0, 0]

pennylane/devices/qubit/dq_interpreter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _(self, *invals, reset, postselect):
231231

232232

233233
@DefaultQubitInterpreter.register_primitive(adjoint_transform_prim)
234-
def _(self, *invals, jaxpr, n_consts, lazy=True):
234+
def _adjoint_transform_prim(self, *invals, jaxpr, n_consts, lazy=True):
235235
consts = invals[:n_consts]
236236
args = invals[n_consts:]
237237
recorder = CollectOpsandMeas()
@@ -251,7 +251,7 @@ def _(self, *invals, jaxpr, n_consts, lazy=True):
251251

252252
# pylint: disable=too-many-arguments
253253
@DefaultQubitInterpreter.register_primitive(ctrl_transform_prim)
254-
def _(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts):
254+
def _ctrl_transform_prim(self, *invals, n_control, jaxpr, control_values, work_wires, n_consts):
255255
consts = invals[:n_consts]
256256
control_wires = invals[-n_control:]
257257
args = invals[n_consts:-n_control]

0 commit comments

Comments
 (0)