Skip to content

Commit d0089d1

Browse files
pass optional input to validate fn
1 parent 00fd928 commit d0089d1

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

mellea/stdlib/session.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,11 @@ def act(
300300
else:
301301
# Default validation strategy just validates all of the provided requirements.
302302
if strategy.validate is None:
303-
strategy.validate = lambda reqs, val_ctx, output: self.validate(
304-
reqs, output=output
305-
)
303+
strategy.validate = (
304+
lambda reqs, val_ctx, output, input=None: self.validate( # type: ignore
305+
reqs, output=output, input=input
306+
)
307+
) # type: ignore
306308

307309
# Default generation strategy just generates from context.
308310
if strategy.generate is None:
@@ -483,6 +485,7 @@ def validate(
483485
format: type[BaseModelSubclass] | None = None,
484486
model_options: dict | None = None,
485487
generate_logs: list[GenerateLog] | None = None,
488+
input: CBlock | None = None,
486489
) -> list[ValidationResult]:
487490
"""Validates a set of requirements over the output (if provided) or the current context (if the output is not provided)."""
488491
# Turn a solitary requirement in to a list of requirements, and then reqify if needed.
@@ -492,7 +495,17 @@ def validate(
492495
validation_target_ctx = self.ctx
493496
else:
494497
validation_target_ctx = SimpleContext()
495-
validation_target_ctx.insert(output)
498+
if input is not None:
499+
# some validators may need input as well as output
500+
validation_target_ctx.insert_turn(
501+
ContextTurn(
502+
input,
503+
output, # type: ignore
504+
), # type: ignore
505+
generate_logs=generate_logs,
506+
)
507+
else:
508+
validation_target_ctx.insert(output)
496509
rvs = []
497510
for requirement in reqs:
498511
val_result = requirement.validate(

0 commit comments

Comments
 (0)