|
44 | 44 | logger.addHandler(logging.NullHandler()) |
45 | 45 |
|
46 | 46 | if TYPE_CHECKING: |
| 47 | + from typing import TypeAlias |
| 48 | + |
47 | 49 | from pennylane.concurrency.executors import ExecBackends |
48 | 50 | from pennylane.devices import Device, LegacyDevice |
49 | 51 | from pennylane.math import SupportedInterfaceUserInput |
50 | 52 | from pennylane.transforms.core import TransformContainer |
51 | 53 | from pennylane.typing import Result |
52 | 54 | from pennylane.workflow.resolution import SupportedDiffMethods |
53 | 55 |
|
54 | | - SupportedDeviceAPIs = LegacyDevice | Device |
| 56 | + SupportedDeviceAPIs: TypeAlias = LegacyDevice | Device |
55 | 57 |
|
56 | 58 |
|
57 | 59 | def _convert_to_interface(result, interface: Interface): |
@@ -119,7 +121,10 @@ def _to_qfunc_output_type(results: Result, qfunc_output, has_partitioned_shots) |
119 | 121 | return pytrees.unflatten(results, qfunc_output_structure) |
120 | 122 |
|
121 | 123 |
|
122 | | -def _validate_mcm_config(postselect_mode: str, mcm_method: str) -> None: |
| 124 | +def _validate_mcm_config( |
| 125 | + postselect_mode: Literal["hw-like", "fill-shots"] | None, |
| 126 | + mcm_method: Literal["deferred", "one-shot", "tree-traversal"] | None, |
| 127 | +) -> None: |
123 | 128 | qml.devices.MCMConfig(postselect_mode=postselect_mode, mcm_method=mcm_method) |
124 | 129 |
|
125 | 130 |
|
@@ -581,7 +586,7 @@ def __init__( |
581 | 586 | self.diff_method = diff_method |
582 | 587 | _validate_diff_method(self.device, self.diff_method) |
583 | 588 |
|
584 | | - self.capture_cache = LRUCache(maxsize=1000) |
| 589 | + self.capture_cache: LRUCache = LRUCache(maxsize=1000) |
585 | 590 | if isinstance(static_argnums, int): |
586 | 591 | static_argnums = (static_argnums,) |
587 | 592 | self.static_argnums = sorted(static_argnums) |
@@ -727,7 +732,7 @@ def circuit(x): |
727 | 732 | original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {} |
728 | 733 | # nested dictionary update |
729 | 734 | new_gradient_kwargs = kwargs.pop("gradient_kwargs", {}) |
730 | | - old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy() |
| 735 | + old_gradient_kwargs = (original_init_args.get("gradient_kwargs", {})).copy() |
731 | 736 | old_gradient_kwargs.update(new_gradient_kwargs) |
732 | 737 | kwargs["gradient_kwargs"] = old_gradient_kwargs |
733 | 738 |
|
|
0 commit comments