|
63 | 63 | from ._version import kernel_protocol_version |
64 | 64 |
|
65 | 65 |
|
66 | | -def _accepts_cell_id(meth): |
| 66 | +def _accepts_parameters(meth, param_names): |
67 | 67 | parameters = inspect.signature(meth).parameters |
68 | | - cid_param = parameters.get("cell_id") |
69 | | - return (cid_param and cid_param.kind == cid_param.KEYWORD_ONLY) or any( |
70 | | - p.kind == p.VAR_KEYWORD for p in parameters.values() |
71 | | - ) |
| 68 | + accepts = {param: False for param in param_names} |
| 69 | + |
| 70 | + for param in param_names: |
| 71 | + param_spec = parameters.get(param) |
| 72 | + accepts[param] = ( |
| 73 | + param_spec |
| 74 | + and param_spec.kind in [param_spec.KEYWORD_ONLY, param_spec.POSITIONAL_OR_KEYWORD] |
| 75 | + ) or any(p.kind == p.VAR_KEYWORD for p in parameters.values()) |
| 76 | + |
| 77 | + return accepts |
72 | 78 |
|
73 | 79 |
|
74 | 80 | class Kernel(SingletonConfigurable): |
@@ -735,25 +741,28 @@ async def execute_request(self, stream, ident, parent): |
735 | 741 | self.execution_count += 1 |
736 | 742 | self._publish_execute_input(code, parent, self.execution_count) |
737 | 743 |
|
738 | | - cell_id = (parent.get("metadata") or {}).get("cellId") |
| 744 | + cell_meta = parent.get("metadata", {}) |
| 745 | + cell_id = cell_meta.get("cellId") |
739 | 746 |
|
740 | | - if _accepts_cell_id(self.do_execute): |
741 | | - reply_content = self.do_execute( |
742 | | - code, |
743 | | - silent, |
744 | | - store_history, |
745 | | - user_expressions, |
746 | | - allow_stdin, |
747 | | - cell_id=cell_id, |
748 | | - ) |
749 | | - else: |
750 | | - reply_content = self.do_execute( |
751 | | - code, |
752 | | - silent, |
753 | | - store_history, |
754 | | - user_expressions, |
755 | | - allow_stdin, |
756 | | - ) |
| 747 | + # Check which parameters do_execute can accept |
| 748 | + accepts_params = _accepts_parameters(self.do_execute, ["cell_meta", "cell_id"]) |
| 749 | + |
| 750 | + # Arguments based on the do_execute signature |
| 751 | + do_execute_args = { |
| 752 | + "code": code, |
| 753 | + "silent": silent, |
| 754 | + "store_history": store_history, |
| 755 | + "user_expressions": user_expressions, |
| 756 | + "allow_stdin": allow_stdin, |
| 757 | + } |
| 758 | + |
| 759 | + if accepts_params["cell_meta"]: |
| 760 | + do_execute_args["cell_meta"] = cell_meta |
| 761 | + if accepts_params["cell_id"]: |
| 762 | + do_execute_args["cell_id"] = cell_id |
| 763 | + |
| 764 | + # Call do_execute with the appropriate arguments |
| 765 | + reply_content = self.do_execute(**do_execute_args) |
757 | 766 |
|
758 | 767 | if inspect.isawaitable(reply_content): |
759 | 768 | reply_content = await reply_content |
@@ -793,6 +802,7 @@ def do_execute( |
793 | 802 | user_expressions=None, |
794 | 803 | allow_stdin=False, |
795 | 804 | *, |
| 805 | + cell_meta=None, |
796 | 806 | cell_id=None, |
797 | 807 | ): |
798 | 808 | """Execute user code. Must be overridden by subclasses.""" |
|
0 commit comments