diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 5ff344f5..5e9ff9df 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -144,9 +144,9 @@ async def _act( "Must provide a SamplingStrategy when return_sampling_results==True" ) - # if there is no reason to sample, just generate from the context. - if strategy is None or requirements is None or len(requirements) == 0: - if strategy is None and requirements is not None: + if strategy is None: + # Only use the strategy if one is provided. Add a warning if requirements were passed in though. + if requirements is not None and len(requirements) >= 0: FancyLogger.get_logger().warning( "Calling the function with NO strategy BUT requirements. No requirement is being checked!" ) @@ -394,6 +394,7 @@ async def _validate( # Turn a solitary requirement in to a list of requirements, and then reqify if needed. reqs = [reqs] if not isinstance(reqs, list) else reqs reqs = [Requirement(req) if type(req) is str else req for req in reqs] + if output is None: validation_target_ctx = context else: diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index fae5a922..12532f7b 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -86,7 +86,7 @@ async def sample( action: Component, context: Context, backend: Backend, - requirements: list[Requirement], + requirements: list[Requirement] | None, *, validation_ctx: Context | None = None, format: type[BaseModelSubclass] | None = None, @@ -123,7 +123,7 @@ async def sample( show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO reqs = [] - # global requirements supersede local requirements (global requiremenst can be defined by user) + # global requirements supersede local requirements (global requirements can be defined by user) # Todo: re-evaluate if this makes sense if self.requirements is not None: reqs += self.requirements diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index b165d8d6..53cfa83c 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -22,7 +22,7 @@ async def sample( action: Component, context: Context, backend: Backend, - requirements: list[Requirement], + requirements: list[Requirement] | None, *, validation_ctx: Context | None = None, format: type[BaseModelSubclass] | None = None, diff --git a/mellea/stdlib/sampling/types.py b/mellea/stdlib/sampling/types.py index a7068999..e5b81130 100644 --- a/mellea/stdlib/sampling/types.py +++ b/mellea/stdlib/sampling/types.py @@ -81,7 +81,7 @@ async def sample( action: Component, context: Context, backend: Backend, - requirements: list[Requirement], + requirements: list[Requirement] | None, *, validation_ctx: Context | None = None, format: type[BaseModelSubclass] | None = None, diff --git a/test/stdlib_basics/test_vision_ollama.py b/test/stdlib_basics/test_vision_ollama.py index d0c0ed1c..275b5db3 100644 --- a/test/stdlib_basics/test_vision_ollama.py +++ b/test/stdlib_basics/test_vision_ollama.py @@ -61,7 +61,9 @@ def test_image_block_construction_from_pil(pil_image: Image.Image): def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int): image_block = ImageBlock.from_pil_image(pil_image) - instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block]) + + # Set strategy=None here since we are directly comparing the object and sampling strategies tend to do a deepcopy. + instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block], strategy=None) assert isinstance(instr, ModelOutputThunk) # if not on GH diff --git a/test/stdlib_basics/test_vision_openai.py b/test/stdlib_basics/test_vision_openai.py index 385a8ffe..4cf22415 100644 --- a/test/stdlib_basics/test_vision_openai.py +++ b/test/stdlib_basics/test_vision_openai.py @@ -65,7 +65,9 @@ def test_image_block_construction_from_pil(pil_image: Image.Image): def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int): image_block = ImageBlock.from_pil_image(pil_image) - instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block]) + + # Set strategy=None here since we are directly comparing the object and sampling strategies tend to do a deepcopy. + instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block], strategy=None) assert isinstance(instr, ModelOutputThunk) # if not on GH