|
5 | 5 | LICENSE file in the root directory of this source tree. |
6 | 6 | """ |
7 | 7 |
|
| 8 | +import inspect |
8 | 9 | import json |
9 | 10 | import logging |
10 | 11 | from abc import ABC |
@@ -225,6 +226,13 @@ def __init_subclass__(cls, **kwargs): |
225 | 226 | "The DEPRECATED 'process' method must not be implemented " |
226 | 227 | "alongside 'process_input' or 'process_response'." |
227 | 228 | ) |
| 229 | + if is_process_overridden and inspect.iscoroutinefunction(cls.process): |
| 230 | + # we don't want to add async capabilities to the deprecated function |
| 231 | + raise TypeError( |
| 232 | + f"Cannot create concrete class {cls.__name__}. " |
| 233 | + "The DEPRECATED 'process' method does not support async. " |
| 234 | + "Implement 'process_input' and/or 'process_response' instead." |
| 235 | + ) |
228 | 236 |
|
229 | 237 | return |
230 | 238 |
|
@@ -875,15 +883,18 @@ async def _parse_and_process(self, request: Request) -> Response: |
875 | 883 | prompt_hash, response_hash = (None, None) |
876 | 884 | if input_direction: |
877 | 885 | prompt_hash = prompt.hash() |
878 | | - result: Result | Reject = self.process_input( |
| 886 | + result = await self._handle_process_function( |
| 887 | + self.process_input, |
879 | 888 | metadata=metadata, |
880 | 889 | parameters=parameters, |
881 | 890 | prompt=prompt, |
882 | 891 | request=request, |
883 | 892 | ) |
| 893 | + |
884 | 894 | else: |
885 | 895 | response_hash = response.hash() |
886 | | - result: Result | Reject = self.process_response( |
| 896 | + result = await self._handle_process_function( |
| 897 | + self.process_response, |
887 | 898 | metadata=metadata, |
888 | 899 | parameters=parameters, |
889 | 900 | prompt=prompt, |
@@ -1014,7 +1025,16 @@ def _is_method_overridden(self, method_name: str) -> bool: |
1014 | 1025 | # the method object directly from the Processor class, then it has been overridden. |
1015 | 1026 | return instance_class_method_obj is not base_class_method_obj |
1016 | 1027 |
|
1017 | | - def process_input( |
| 1028 | + async def _process_fallback(self, **kwargs) -> Result | Reject: |
| 1029 | + warnings.warn( |
| 1030 | + f"{type(self).__name__} uses the deprecated 'process' method. " |
| 1031 | + "Implement 'process_input' and/or 'process_response' instead.", |
| 1032 | + DeprecationWarning, |
| 1033 | + stacklevel=2, |
| 1034 | + ) |
| 1035 | + return await self._handle_process_function(self.process, **kwargs) |
| 1036 | + |
| 1037 | + async def process_input( |
1018 | 1038 | self, |
1019 | 1039 | prompt: PROMPT, |
1020 | 1040 | metadata: Metadata, |
@@ -1043,26 +1063,20 @@ def process_input(self, prompt, response, metadata, parameters, request): |
1043 | 1063 |
|
1044 | 1064 | return Result(processor_result=result) |
1045 | 1065 | """ |
1046 | | - if self._is_method_overridden("process"): |
1047 | | - warnings.warn( |
1048 | | - f"{type(self).__name__} uses the deprecated 'process' method for input. " |
1049 | | - "Implement 'process_input' instead.", |
1050 | | - DeprecationWarning, |
1051 | | - stacklevel=2, # Points the warning to the caller of process_input |
| 1066 | + if not self._is_method_overridden("process"): |
| 1067 | + raise NotImplementedError( |
| 1068 | + f"{type(self).__name__} must implement 'process_input' or the " |
| 1069 | + "deprecated 'process' method to handle input." |
1052 | 1070 | ) |
1053 | | - return self.process( |
1054 | | - prompt=prompt, |
1055 | | - response=None, |
1056 | | - metadata=metadata, |
1057 | | - parameters=parameters, |
1058 | | - request=request, |
1059 | | - ) |
1060 | | - raise NotImplementedError( |
1061 | | - f"{type(self).__name__} must implement 'process_input' or the " |
1062 | | - "deprecated 'process' method to handle input." |
| 1071 | + return await self._process_fallback( |
| 1072 | + prompt=prompt, |
| 1073 | + response=None, |
| 1074 | + metadata=metadata, |
| 1075 | + parameters=parameters, |
| 1076 | + request=request, |
1063 | 1077 | ) |
1064 | 1078 |
|
1065 | | - def process_response( |
| 1079 | + async def process_response( |
1066 | 1080 | self, |
1067 | 1081 | prompt: PROMPT | None, |
1068 | 1082 | response: RESPONSE, |
@@ -1096,23 +1110,17 @@ def process_response(self, prompt, response, metadata, parameters, request): |
1096 | 1110 | return Result(processor_result=result) |
1097 | 1111 | """ |
1098 | 1112 |
|
1099 | | - if self._is_method_overridden("process"): |
1100 | | - warnings.warn( |
1101 | | - f"{type(self).__name__} uses the deprecated 'process' method for response. " |
1102 | | - "Implement 'process_response' instead.", |
1103 | | - DeprecationWarning, |
1104 | | - stacklevel=2, # Points the warning to the caller of process_input |
| 1113 | + if not self._is_method_overridden("process"): |
| 1114 | + raise NotImplementedError( |
| 1115 | + f"{type(self).__name__} must implement 'process_response' or the " |
| 1116 | + "deprecated 'process' method to handle input." |
1105 | 1117 | ) |
1106 | | - return self.process( |
1107 | | - prompt=prompt, |
1108 | | - response=response, |
1109 | | - metadata=metadata, |
1110 | | - parameters=parameters, |
1111 | | - request=request, |
1112 | | - ) |
1113 | | - raise NotImplementedError( |
1114 | | - f"{type(self).__name__} must implement 'process_response' or the " |
1115 | | - "deprecated 'process' method to handle input." |
| 1118 | + return await self._process_fallback( |
| 1119 | + prompt=prompt, |
| 1120 | + response=response, |
| 1121 | + metadata=metadata, |
| 1122 | + parameters=parameters, |
| 1123 | + request=request, |
1116 | 1124 | ) |
1117 | 1125 |
|
1118 | 1126 | def process( |
@@ -1159,6 +1167,13 @@ def process(self, prompt, response, metadata, parameters, request): |
1159 | 1167 | "'process_input'/'process_response'." |
1160 | 1168 | ) |
1161 | 1169 |
|
| 1170 | + async def _handle_process_function(self, func, **kwargs) -> Result | Reject: |
| 1171 | + if inspect.iscoroutinefunction(func): |
| 1172 | + result = await func(**kwargs) |
| 1173 | + else: |
| 1174 | + result = func(**kwargs) |
| 1175 | + return result |
| 1176 | + |
1162 | 1177 |
|
1163 | 1178 | def _validation_error_as_messages(err: ValidationError) -> list[str]: |
1164 | 1179 | return [_error_details_to_str(e) for e in err.errors()] |
|
0 commit comments