Skip to content

Commit 8fece40

Browse files
authored
fix: always call sample when a strategy is provided (#176)
* fix: always call sample when a strategy is provided; add warning to validate if no requirements * fix: tests that were doing object comparisons with sampling * fix: add comment to changed tests * fix: remove validation warning message
1 parent 342b404 commit 8fece40

File tree

6 files changed

+14
-9
lines changed

6 files changed

+14
-9
lines changed

mellea/stdlib/funcs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ async def _act(
144144
"Must provide a SamplingStrategy when return_sampling_results==True"
145145
)
146146

147-
# if there is no reason to sample, just generate from the context.
148-
if strategy is None or requirements is None or len(requirements) == 0:
149-
if strategy is None and requirements is not None:
147+
if strategy is None:
148+
# Only use the strategy if one is provided. Add a warning if requirements were passed in though.
149+
if requirements is not None and len(requirements) >= 0:
150150
FancyLogger.get_logger().warning(
151151
"Calling the function with NO strategy BUT requirements. No requirement is being checked!"
152152
)
@@ -394,6 +394,7 @@ async def _validate(
394394
# Turn a solitary requirement in to a list of requirements, and then reqify if needed.
395395
reqs = [reqs] if not isinstance(reqs, list) else reqs
396396
reqs = [Requirement(req) if type(req) is str else req for req in reqs]
397+
397398
if output is None:
398399
validation_target_ctx = context
399400
else:

mellea/stdlib/sampling/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async def sample(
8686
action: Component,
8787
context: Context,
8888
backend: Backend,
89-
requirements: list[Requirement],
89+
requirements: list[Requirement] | None,
9090
*,
9191
validation_ctx: Context | None = None,
9292
format: type[BaseModelSubclass] | None = None,
@@ -123,7 +123,7 @@ async def sample(
123123
show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO
124124

125125
reqs = []
126-
# global requirements supersede local requirements (global requiremenst can be defined by user)
126+
# global requirements supersede local requirements (global requirements can be defined by user)
127127
# Todo: re-evaluate if this makes sense
128128
if self.requirements is not None:
129129
reqs += self.requirements

mellea/stdlib/sampling/best_of_n.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async def sample(
2222
action: Component,
2323
context: Context,
2424
backend: Backend,
25-
requirements: list[Requirement],
25+
requirements: list[Requirement] | None,
2626
*,
2727
validation_ctx: Context | None = None,
2828
format: type[BaseModelSubclass] | None = None,

mellea/stdlib/sampling/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def sample(
8181
action: Component,
8282
context: Context,
8383
backend: Backend,
84-
requirements: list[Requirement],
84+
requirements: list[Requirement] | None,
8585
*,
8686
validation_ctx: Context | None = None,
8787
format: type[BaseModelSubclass] | None = None,

test/stdlib_basics/test_vision_ollama.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def test_image_block_construction_from_pil(pil_image: Image.Image):
6161

6262
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
6363
image_block = ImageBlock.from_pil_image(pil_image)
64-
instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block])
64+
65+
# Set strategy=None here since we are directly comparing the object and sampling strategies tend to do a deepcopy.
66+
instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block], strategy=None)
6567
assert isinstance(instr, ModelOutputThunk)
6668

6769
# if not on GH

test/stdlib_basics/test_vision_openai.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def test_image_block_construction_from_pil(pil_image: Image.Image):
6565

6666
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
6767
image_block = ImageBlock.from_pil_image(pil_image)
68-
instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block])
68+
69+
# Set strategy=None here since we are directly comparing the object and sampling strategies tend to do a deepcopy.
70+
instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block], strategy=None)
6971
assert isinstance(instr, ModelOutputThunk)
7072

7173
# if not on GH

0 commit comments

Comments
 (0)