Skip to content

Commit 7b2701a

Browse files
authored
chore: clean-up type hints in qnode.py and resolution.py (#8086)
**Context:** Fix some annoying `mypy` complaints. **Benefits:** Better code.
1 parent 07608a3 commit 7b2701a

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@
462462

463463
<h3>Internal changes ⚙️</h3>
464464

465+
* Improve type hinting internally.
466+
[(#8086)](https://github.com/PennyLaneAI/pennylane/pull/8086)
467+
465468
* The `cond` primitive with program capture no longer stores missing false branches as `None`, instead storing them
466469
as jaxprs with no output.
467470
[(#8080)](https://github.com/PennyLaneAI/pennylane/pull/8080)

pennylane/workflow/qnode.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,16 @@
4444
logger.addHandler(logging.NullHandler())
4545

4646
if TYPE_CHECKING:
47+
from typing import TypeAlias
48+
4749
from pennylane.concurrency.executors import ExecBackends
4850
from pennylane.devices import Device, LegacyDevice
4951
from pennylane.math import SupportedInterfaceUserInput
5052
from pennylane.transforms.core import TransformContainer
5153
from pennylane.typing import Result
5254
from pennylane.workflow.resolution import SupportedDiffMethods
5355

54-
SupportedDeviceAPIs = LegacyDevice | Device
56+
SupportedDeviceAPIs: TypeAlias = LegacyDevice | Device
5557

5658

5759
def _convert_to_interface(result, interface: Interface):
@@ -119,7 +121,10 @@ def _to_qfunc_output_type(results: Result, qfunc_output, has_partitioned_shots)
119121
return pytrees.unflatten(results, qfunc_output_structure)
120122

121123

122-
def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None:
124+
def _validate_mcm_config(
125+
postselect_mode: Literal["hw-like", "fill-shots"] | None,
126+
mcm_method: Literal["deferred", "one-shot", "tree-traversal"] | None,
127+
) -> None:
123128
qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method)
124129

125130

@@ -581,7 +586,7 @@ def __init__(
581586
self.diff_method = diff_method
582587
_validate_diff_method(self.device, self.diff_method)
583588

584-
self.capture_cache = LRUCache(maxsize=1000)
589+
self.capture_cache: LRUCache = LRUCache(maxsize=1000)
585590
if isinstance(static_argnums, int):
586591
static_argnums = (static_argnums,)
587592
self.static_argnums = sorted(static_argnums)
@@ -727,7 +732,7 @@ def circuit(x):
727732
original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {}
728733
# nested dictionary update
729734
new_gradient_kwargs = kwargs.pop("gradient_kwargs", {})
730-
old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy()
735+
old_gradient_kwargs = (original_init_args.get("gradient_kwargs", {})).copy()
731736
old_gradient_kwargs.update(new_gradient_kwargs)
732737
kwargs["gradient_kwargs"] = old_gradient_kwargs
733738

pennylane/workflow/resolution.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515

1616
from __future__ import annotations
1717

18-
from collections.abc import Callable
1918
from copy import copy
2019
from dataclasses import replace
2120
from importlib.metadata import version
2221
from importlib.util import find_spec
23-
from typing import TYPE_CHECKING, Literal, get_args
22+
from typing import TYPE_CHECKING, Any, Literal, get_args
2423
from warnings import warn
2524

2625
from packaging.version import Version
@@ -146,7 +145,7 @@ def _resolve_mcm_config(
146145
) -> qml.devices.MCMConfig:
147146
"""Helper function to resolve the mid-circuit measurements configuration based on
148147
execution parameters"""
149-
updated_values = {}
148+
updated_values: dict[str, Any] = {}
150149

151150
if not finite_shots:
152151
updated_values["postselect_mode"] = None
@@ -287,10 +286,10 @@ def _resolve_execution_config(
287286
Returns:
288287
qml.devices.ExecutionConfig: resolved execution configuration
289288
"""
290-
updated_values = {}
289+
updated_values: dict[str, Any] = {}
291290

292-
if execution_config.interface in {Interface.JAX, Interface.JAX_JIT} and not isinstance(
293-
execution_config.gradient_method, Callable
291+
if execution_config.interface in {Interface.JAX, Interface.JAX_JIT} and not callable(
292+
execution_config.gradient_method
294293
):
295294
updated_values["grad_on_execution"] = False
296295
execution_config = _resolve_diff_method(execution_config, device, tape=tapes[0])

0 commit comments

Comments
 (0)