8 | 8 | cast, |
9 | 9 | ) |
10 | 10 |
|
11 | | - |
| 11 | +from guardrails.validator_service import AsyncValidatorService |
12 | 12 | from guardrails.actions.reask import SkeletonReAsk |
13 | 13 | from guardrails.classes import ValidationOutcome |
14 | 14 | from guardrails.classes.history import Call, Inputs, Iteration, Outputs |
15 | 15 | from guardrails.classes.output_type import OutputTypes |
16 | | -from guardrails.constants import pass_status |
17 | 16 | from guardrails.llm_providers import ( |
18 | 17 | AsyncLiteLLMCallable, |
19 | 18 | AsyncPromptCallableBase, |
28 | 27 | from guardrails.run.async_runner import AsyncRunner |
29 | 28 | from guardrails.telemetry import trace_async_stream_step |
30 | 29 | from guardrails.hub_telemetry.hub_tracing import async_trace_stream |
| 30 | +from guardrails.types import OnFailAction |
| 31 | +from guardrails.classes.validation.validation_result import ( |
| 32 | + PassResult, |
| 33 | + FailResult, |
| 34 | +) |
31 | 35 |
|
32 | 36 |
|
33 | 37 | class AsyncStreamRunner(AsyncRunner, StreamRunner): |
@@ -133,49 +137,113 @@ async def async_step( |
133 | 137 | parsed_fragment, validated_fragment, valid_op = None, None, None |
134 | 138 | verified = set() |
135 | 139 | validation_response = "" |
| 140 | + validation_progress = {} |
| 141 | + refrain_triggered = False |
| 142 | + validation_passed = True |
136 | 143 |
|
137 | 144 | if self.output_type == OutputTypes.STRING: |
| 145 | + validator_service = AsyncValidatorService(self.disable_tracer) |
138 | 146 | async for chunk in stream_output: |
139 | 147 | chunk_text = self.get_chunk_text(chunk, api) |
140 | 148 | _ = self.is_last_chunk(chunk, api) |
141 | | - fragment += chunk_text |
142 | 149 |
|
143 | | - parsed_chunk, move_to_next = self.parse( |
144 | | - chunk_text, output_schema, verified=verified |
145 | | - ) |
146 | | - if move_to_next: |
147 | | - continue |
148 | | - validated_fragment = await self.async_validate( |
| 150 | + fragment += chunk_text |
| 151 | + results = await validator_service.async_partial_validate( |
| 152 | + chunk_text, |
| 153 | + self.metadata, |
| 154 | + self.validation_map, |
149 | 155 | iteration, |
150 | | - index, |
151 | | - parsed_chunk, |
152 | | - output_schema, |
153 | | - validate_subschema=True, |
154 | | - stream=True, |
| 156 | + "$", |
| 157 | + "$", |
| 158 | + True, |
155 | 159 | ) |
156 | | - # TODO why? how does it happen in the other places we handle streams |
157 | | - if validated_fragment is None: |
158 | | - validated_fragment = "" |
159 | | - |
160 | | - if isinstance(validated_fragment, SkeletonReAsk): |
161 | | - raise ValueError( |
162 | | - "Received fragment schema is an invalid sub-schema " |
163 | | - "of the expected output JSON schema." |
| 160 | + validators = self.validation_map["$"] or [] |
| 161 | +                # collect each result's validated_chunk into
| 162 | +                # validation_progress, keyed per validator
| 163 | + for result in results: |
| 164 | + validator_log = result.validator_logs # type: ignore |
| 165 | + validator = next( |
| 166 | + filter( |
| 167 | + lambda x: x.rail_alias == validator_log.registered_name, |
| 168 | + validators, |
| 169 | + ), |
| 170 | + None, |
164 | 171 | ) |
| 172 | + if ( |
| 173 | + validator_log.validation_result |
| 174 | + and validator_log.validation_result.validated_chunk |
| 175 | + ): |
| 176 | + is_filter = validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore |
| 177 | + is_refrain = ( |
| 178 | + validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore |
| 179 | + ) |
| 180 | + if validator_log.validation_result.outcome == "fail": |
| 181 | + validation_passed = False |
| 182 | + reasks, valid_op = self.introspect( |
| 183 | + validator_log.validation_result |
| 184 | + ) |
| 185 | + if reasks: |
| 186 | + raise ValueError( |
| 187 | + "Reasks are not yet supported with streaming. Please " |
| 188 | + "remove reasks from schema or disable streaming." |
| 189 | + ) |
165 | 190 |
|
166 | | - reasks, valid_op = self.introspect(validated_fragment) |
167 | | - if reasks: |
168 | | - raise ValueError( |
169 | | - "Reasks are not yet supported with streaming. Please " |
170 | | - "remove reasks from schema or disable streaming." |
| 191 | + if isinstance(validator_log.validation_result, PassResult): |
| 192 | + chunk = validator_log.validation_result.validated_chunk |
| 193 | + elif isinstance(validator_log.validation_result, FailResult): |
| 194 | + if is_filter or is_refrain: |
| 195 | + refrain_triggered = True |
| 196 | + chunk = "" |
| 197 | + else: |
| 198 | + chunk = validator_service.perform_correction( |
| 199 | + validator_log.validation_result, |
| 200 | + validator_log.validation_result.validated_chunk, |
| 201 | + validator, # type: ignore |
| 202 | + rechecked_value=None, |
| 203 | + ) # type: ignore |
| 204 | + |
| 205 | +                        if (
| 206 | +                            validator_log.validator_name not in validation_progress
| 207 | +                        ):
| 208 | + validation_progress[validator_log.validator_name] = "" |
| 209 | + |
| 210 | + validation_progress[validator_log.validator_name] += chunk |
| 211 | +                # once every validator has contributed a chunk,
| 212 | +                # merge them and emit a validation outcome
| 213 | + if len(validation_progress) == len(validators): |
| 214 | + if refrain_triggered: |
| 215 | + current = "" |
| 216 | + else: |
| 217 | + merge_chunks = [] |
| 218 | + for piece in validation_progress: |
| 219 | + merge_chunks.append(validation_progress[piece]) |
| 220 | + |
| 221 | + current = validator_service.multi_merge(fragment, merge_chunks) |
| 222 | + |
| 223 | + vo = ValidationOutcome( |
| 224 | + call_id=call_log.id, # type: ignore |
| 225 | + raw_llm_output=fragment, |
| 226 | + validated_output=current, |
| 227 | + validation_passed=True, |
171 | 228 | ) |
172 | | - validation_response += validated_fragment |
173 | | - passed = call_log.status == pass_status |
| 229 | + fragment = "" |
| 230 | + validation_progress = {} |
| 231 | + refrain_triggered = False |
| 232 | + |
| 233 | + yield vo |
| 234 | + |
| 235 | +            # if there's anything left, merge and emit a final chunk
| 236 | + if len(validation_progress) > 0: |
| 237 | + merge_chunks = [] |
| 238 | + for piece in validation_progress: |
| 239 | + merge_chunks.append(validation_progress[piece]) |
| 240 | + |
| 241 | + current = validator_service.multi_merge(fragment, merge_chunks) |
174 | 242 | yield ValidationOutcome( |
175 | 243 | call_id=call_log.id, # type: ignore |
176 | | - raw_llm_output=chunk_text, |
177 | | - validated_output=validated_fragment, |
178 | | - validation_passed=passed, |
| 244 | + raw_llm_output=fragment, |
| 245 | + validated_output=current, |
| 246 | + validation_passed=validation_passed, |
179 | 247 | ) |
180 | 248 | else: |
181 | 249 | async for chunk in stream_output: |
|
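For readers skimming the hunk, the new string-streaming branch follows an accumulate-and-merge pattern: each streamed chunk is partially validated by every validator, each validator's validated text is accumulated under its name, and once every validator has contributed, the per-validator strings are merged against the raw fragment and emitted as a ValidationOutcome, with one final merge for whatever remains when the stream ends. Below is a minimal, self-contained sketch of that pattern only; the ChunkOutcome container, the validator callables, and naive_merge are illustrative stand-ins, not the guardrails API (the real code goes through AsyncValidatorService.async_partial_validate, perform_correction, and multi_merge).

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Iterator, List


@dataclass
class ChunkOutcome:
    # Illustrative stand-in for guardrails' ValidationOutcome.
    raw_llm_output: str
    validated_output: str
    validation_passed: bool


def naive_merge(fragment: str, pieces: List[str]) -> str:
    # Placeholder merge: keep the shortest per-validator result.
    # The real multi_merge reconciles each validator's edits against
    # the original fragment; this stand-in only marks the call site.
    return min(pieces, key=len) if pieces else fragment


def stream_validate(
    chunks: Iterable[str],
    validators: Dict[str, Callable[[str], str]],
) -> Iterator[ChunkOutcome]:
    """Accumulate validated text per validator; emit once all have contributed."""
    fragment = ""
    progress: Dict[str, str] = {}
    for chunk in chunks:
        fragment += chunk
        for name, validate in validators.items():
            # Each validator returns its (possibly corrected) view of the chunk.
            progress[name] = progress.get(name, "") + validate(chunk)
        if len(progress) == len(validators):
            merged = naive_merge(fragment, list(progress.values()))
            yield ChunkOutcome(fragment, merged, True)
            fragment, progress = "", {}
    if progress:
        # Anything left over after the stream ends gets one final merge.
        yield ChunkOutcome(fragment, naive_merge(fragment, list(progress.values())), True)


# Usage: one validator lowercases, another strips exclamation marks.
for outcome in stream_validate(
    ["Hello ", "WORLD!"],
    {"lower": str.lower, "no_bang": lambda s: s.replace("!", "")},
):
    print(outcome.raw_llm_output, "->", outcome.validated_output)
```

Resetting fragment and the progress map after each emit is what keeps raw_llm_output aligned with the validated text of that emission, mirroring the `fragment = ""` and `validation_progress = {}` resets in the diff above.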