diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 7634b8be..d68b457f 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -17,6 +17,9 @@ jobs: strategy: matrix: python-version: ['3.10', '3.11', '3.12'] # Need to add 3.13 once we resolve outlines issues. + env: + CICD: 1 + OLLAMA_HOST: "127.0.0.1:5000" steps: - uses: actions/checkout@v4 - name: Install uv and set the python version @@ -31,9 +34,22 @@ jobs: path: ~/.cache/pre-commit key: pre-commit|${{ env.PY }}|${{ hashFiles('.pre-commit-config.yaml') }} - name: Install dependencies - run: uv sync --frozen --all-extras + run: uv sync --frozen --all-extras --group dev - name: Check style and run tests run: pre-commit run --all-files - - name: Send failure message + - name: Send failure message pre-commit if: failure() # This step will only run if a previous step failed run: echo "The quality verification failed. Please run precommit " + - name: Install Ollama + run: curl -fsSL https://ollama.com/install.sh | sh + - name: Start serving ollama + run: nohup ollama serve & + - name: Pull Llama 3.2:1b model + run: ollama pull llama3.2:1b + + - name: Run Tests + run: uv run -m pytest -v test + - name: Send failure message tests + if: failure() # This step will only run if a previous step failed + run: echo "Tests failed. Please verify that tests are working locally." 
+ diff --git a/mellea/backends/model_ids.py b/mellea/backends/model_ids.py index 40db617f..a750998c 100644 --- a/mellea/backends/model_ids.py +++ b/mellea/backends/model_ids.py @@ -89,6 +89,10 @@ class ModelIdentifier: ollama_name="llama-guard3:1b", hf_model_name="unsloth/Llama-Guard-3-1B" ) +META_LLAMA_3_2_1B = ModelIdentifier( + ollama_name="llama3.2:1b", hf_model_name="unsloth/Llama-3.2-1B" +) + ######################## #### Mistral models #### ######################## diff --git a/mellea/stdlib/mify.py b/mellea/stdlib/mify.py index 639bbc80..5b278f2d 100644 --- a/mellea/stdlib/mify.py +++ b/mellea/stdlib/mify.py @@ -132,8 +132,8 @@ def _get_all_fields(self) -> dict[str, Any]: if self._fields_exclude: fields_exclude = self._fields_exclude - # This includes fields defined by any superclasses, as long as it's not object. - all_fields = _get_non_duplicate_fields(self, object) + # This includes fields defined by any superclasses, as long as it's not Protocol. + all_fields = _get_non_duplicate_fields(self, Protocol) # It does matter if include is an empty set. Handle it's cases here. if self._fields_include is not None: @@ -366,18 +366,15 @@ def mification(obj: T) -> T: def _get_non_duplicate_members( - object: object, check_duplicates: object + obj: object, check_duplicates: object ) -> dict[str, Callable]: """Returns all methods/functions unique to the object.""" members = dict( inspect.getmembers( - object, + obj, # Checks for ismethod or isfunction because of the methods added from the MifiedProtocol. 
- predicate=lambda x: inspect.ismethod(x) - or ( - inspect.isfunction(x) - and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys() - ), + predicate=lambda x: (inspect.ismethod(x) or inspect.isfunction(x)) + and x.__name__ not in dict(inspect.getmembers(check_duplicates)).keys(), ) ) return members diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 0d2d6f97..0250b768 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -414,31 +414,6 @@ def validate( return rvs - def req(self, *args, **kwargs): - """Shorthand for Requirement.__init__(...).""" - return req(*args, **kwargs) - - def check(self, *args, **kwargs): - """Shorthand for Requirement.__init__(..., check_only=True).""" - return check(*args, **kwargs) - - def load_default_aloras(self): - """Loads the default Aloras for this model, if they exist and if the backend supports.""" - from mellea.backends.huggingface import LocalHFBackend - - if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance( - self.backend, LocalHFBackend - ): - from mellea.backends.aloras.huggingface.granite_aloras import ( - add_granite_aloras, - ) - - add_granite_aloras(self.backend) - return - self._session_logger.warning( - "This model/backend combination does not support any aloras." - ) - def genslot( self, gen_slot: Component, diff --git a/pyproject.toml b/pyproject.toml index 1548e1d9..4208cdf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,3 +161,8 @@ skip = 'requirements.txt,uv.lock' [tool.mypy] disable_error_code = ["empty-body", "import-untyped"] python_version = "3.10" + +[tool.pytest.ini_options] +markers = [ + "qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. 
All tests marked with this will xfail in CI/CD" +] \ No newline at end of file diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 625f22d3..6859099a 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -37,7 +37,7 @@ def session(backend): yield session session.reset() - +@pytest.mark.qualitative def test_system_prompt(session): result = session.chat( "Where are we going?", @@ -45,7 +45,7 @@ def test_system_prompt(session): ) print(result) - +@pytest.mark.qualitative def test_constraint_alora(session, backend): answer = session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.", @@ -63,7 +63,7 @@ def test_constraint_alora(session, backend): ) assert alora_output in ["Y", "N"], alora_output - +@pytest.mark.qualitative def test_constraint_lora_with_requirement(session, backend): answer = session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" @@ -80,6 +80,7 @@ def test_constraint_lora_with_requirement(session, backend): assert str(val_result.reason) in ["Y", "N"] +@pytest.mark.qualitative def test_constraint_lora_override(session, backend): backend.default_to_constraint_checking_alora = False # type: ignore answer = session.instruct( @@ -95,6 +96,7 @@ def test_constraint_lora_override(session, backend): backend.default_to_constraint_checking_alora = True +@pytest.mark.qualitative def test_constraint_lora_override_does_not_override_alora(session, backend): backend.default_to_constraint_checking_alora = False # type: ignore answer = session.instruct( @@ -112,6 +114,7 @@ def test_constraint_lora_override_does_not_override_alora(session, backend): backend.default_to_constraint_checking_alora = True +@pytest.mark.qualitative def test_llmaj_req_does_not_use_alora(session, backend): backend.default_to_constraint_checking_alora = True # 
type: ignore answer = session.instruct( @@ -127,12 +130,13 @@ def test_llmaj_req_does_not_use_alora(session, backend): assert isinstance(val_result, ValidationResult) assert str(val_result.reason) not in ["Y", "N"] - +@pytest.mark.qualitative def test_instruct(session): result = session.instruct("Compute 1+1.") print(result) +@pytest.mark.qualitative def test_multiturn(session): session.instruct("Compute 1+1") beta = session.instruct( @@ -143,6 +147,7 @@ def test_multiturn(session): print(words) +@pytest.mark.qualitative def test_format(session): class Person(pydantic.BaseModel): name: str @@ -172,7 +177,7 @@ class Email(pydantic.BaseModel): "The email address should be at example.com" ) - +@pytest.mark.qualitative def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] @@ -183,6 +188,7 @@ def test_generate_from_raw(session): assert len(results) == len(prompts) +@pytest.mark.qualitative def test_generate_from_raw_with_format(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 78b07de4..b90d93fb 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -1,11 +1,13 @@ -from mellea import start_session, SimpleContext -from mellea.stdlib.base import CBlock -from mellea.stdlib.requirement import Requirement -import pydantic import json + +import pydantic +import pytest from typing_extensions import Annotated + +from mellea import SimpleContext, start_session from mellea.backends.types import ModelOption -import pytest +from mellea.stdlib.base import CBlock +from mellea.stdlib.requirement import Requirement @pytest.fixture(scope="function") @@ -15,6 +17,8 @@ def session(): yield session session.reset() + +@pytest.mark.qualitative def test_simple_instruct(session): result = session.instruct( "Write an email to Hendrik trying to sell him self-sealing stembolts." 
@@ -23,6 +27,8 @@ def test_simple_instruct(session): assert "chat_response" in result._meta assert result._meta["chat_response"].message.role == "assistant" + +@pytest.mark.qualitative def test_instruct_with_requirement(session): response = session.instruct( "Write an email to Hendrik convincing him to buy some self-sealing stembolts." @@ -45,12 +51,14 @@ def test_instruct_with_requirement(session): ) print(results) +@pytest.mark.qualitative def test_chat(session): output_message = session.chat("What is 1+1?") - assert ( - "2" in output_message.content - ), f"Expected a message with content containing 2 but found {output_message}" + assert "2" in output_message.content, ( + f"Expected a message with content containing 2 but found {output_message}" + ) +@pytest.mark.qualitative def test_format(session): class Person(pydantic.BaseModel): name: str @@ -83,6 +91,7 @@ class Email(pydantic.BaseModel): # assert email.to.email_address.endswith("example.com") pass +@pytest.mark.qualitative def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] @@ -113,9 +122,9 @@ class Answer(pydantic.BaseModel): try: answer = Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: - assert ( - False - ), f"formatting directive failed for {random_result.value}: {e.json()}" + assert False, ( + f"formatting directive failed for {random_result.value}: {e.json()}" + ) if __name__ == "__main__": diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 5bf3b5f3..def41004 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -1,41 +1,54 @@ # test/rits_backend_tests/test_openai_integration.py -from mellea import MelleaSession -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk -from mellea.backends.openai import OpenAIBackend -from mellea.backends.formatter import TemplateFormatter -from mellea.backends.types import 
ModelOption +import os import pydantic -from typing_extensions import Annotated import pytest +from typing_extensions import Annotated + +from mellea import MelleaSession +from mellea.backends.formatter import TemplateFormatter +from mellea.backends.model_ids import META_LLAMA_3_2_1B +from mellea.backends.openai import OpenAIBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk @pytest.fixture(scope="module") -def backend(): +def backend(gh_run: int): """Shared OpenAI backend configured for Ollama.""" - return OpenAIBackend( - model_id="granite3.3:8b", - formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"), - base_url="http://localhost:11434/v1", + if gh_run == 1: + return OpenAIBackend( + model_id=META_LLAMA_3_2_1B, + formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B), + base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", api_key="ollama", ) + else: + return OpenAIBackend( + model_id="granite3.3:8b", + formatter=TemplateFormatter(model_id="ibm-granite/granite-3.2-8b-instruct"), + base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", + api_key="ollama", + ) @pytest.fixture(scope="function") -def session(backend): +def m_session(backend): """Fresh OpenAI session for each test.""" session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) yield session session.reset() -def test_instruct(session): - result = session.instruct("Compute 1+1.") +@pytest.mark.qualitative +def test_instruct(m_session): + result = m_session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore -def test_multiturn(session): - session.instruct("What is the capital of France?") - answer = session.instruct("Tell me the answer to the previous question.") +@pytest.mark.qualitative +def test_multiturn(m_session): + m_session.instruct("What is the capital of France?") + answer = 
m_session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore # def test_api_timeout_error(self): @@ -53,7 +66,8 @@ def test_multiturn(session): # assert "granite3.3:8b" in result.value # self.m.reset() -def test_format(session): +@pytest.mark.qualitative +def test_format(m_session): class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -68,7 +82,7 @@ class Email(pydantic.BaseModel): subject: str body: str - output = session.instruct( + output = m_session.instruct( "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ", format=Email, model_options={ModelOption.MAX_NEW_TOKENS: 2**8}, diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index b62def03..85dedd66 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -1,46 +1,56 @@ # test/rits_backend_tests/test_watsonx_integration.py import os -from mellea import MelleaSession -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk -from mellea.backends.watsonx import WatsonxAIBackend -from mellea.backends.formatter import TemplateFormatter -from mellea.backends.types import ModelOption import pydantic -from typing_extensions import Annotated import pytest +from typing_extensions import Annotated + +from mellea import MelleaSession +from mellea.backends.formatter import TemplateFormatter +from mellea.backends.types import ModelOption +from mellea.backends.watsonx import WatsonxAIBackend +from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk @pytest.fixture(scope="module") def backend(): """Shared Watson backend for all tests in this module.""" - return WatsonxAIBackend( + if os.environ.get("CICD") == "1": + pytest.skip("Skipping watsonx tests.") + else: + return WatsonxAIBackend( model_id="ibm/granite-3-3-8b-instruct",
formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"), ) @pytest.fixture(scope="function") -def session(backend): - """Fresh Watson session for each test.""" - session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) - yield session - session.reset() - - - - -def test_instruct(session): +def session(backend: WatsonxAIBackend): + if os.environ.get("CICD") == "1": + pytest.skip("Skipping watsonx tests.") + else: + """Fresh Watson session for each test.""" + session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) + yield session + session.reset() + + +@pytest.mark.qualitative +def test_instruct(session: MelleaSession): result = session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore -def test_multiturn(session): + +@pytest.mark.qualitative +def test_multiturn(session: MelleaSession): session.instruct("What is the capital of France?") answer = session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore -def test_format(session): + +@pytest.mark.qualitative +def test_format(session: MelleaSession): class Person(pydantic.BaseModel): name: str # it does not support regex patterns in json schema @@ -72,7 +82,9 @@ class Email(pydantic.BaseModel): # assert email.to.email_address.endswith("example.com") pass -def test_generate_from_raw(session): + +@pytest.mark.qualitative +def test_generate_from_raw(session: MelleaSession): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] results = session.backend._generate_from_raw( diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..e95ce41b --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,26 @@ +import os + +import pytest + +from mellea.backends.huggingface import LocalHFBackend +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.openai import OpenAIBackend +from mellea.stdlib.session 
import MelleaSession + + +@pytest.fixture(scope="session") +def gh_run() -> int: + return int(os.environ.get("CICD", 0)) # type: ignore + + +def pytest_runtest_setup(item): + # Runs tests *not* marked with `@pytest.mark.qualitative` to run normally. + if not item.get_closest_marker("qualitative"): + return + + gh_run = int(os.environ.get("CICD", 0)) + + if gh_run == 1: + pytest.xfail( + reason="Skipping qualitative test: got env variable CICD == 1. Used only in gh workflows." + ) diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py index 9ec99db6..6d4008bd 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/stdlib_basics/test_base.py @@ -1,5 +1,4 @@ -from mellea.stdlib.base import Component, CBlock -from mellea.stdlib.base import LinearContext +from mellea.stdlib.base import CBlock, Component, LinearContext def test_cblock(): diff --git a/test/stdlib_basics/test_chat_view.py b/test/stdlib_basics/test_chat_view.py index c0d9e9eb..c56b7c2e 100644 --- a/test/stdlib_basics/test_chat_view.py +++ b/test/stdlib_basics/test_chat_view.py @@ -1,7 +1,8 @@ import pytest -from mellea.stdlib.base import ModelOutputThunk, LinearContext -from mellea.stdlib.chat import as_chat_history, Message + +from mellea.stdlib.base import LinearContext, ModelOutputThunk +from mellea.stdlib.chat import Message, as_chat_history from mellea.stdlib.session import start_session diff --git a/test/stdlib_basics/test_contextual_session.py b/test/stdlib_basics/test_contextual_session.py index 97699831..a142f879 100644 --- a/test/stdlib_basics/test_contextual_session.py +++ b/test/stdlib_basics/test_contextual_session.py @@ -1,10 +1,21 @@ -import pytest from typing import Literal -from mellea import generative, start_session, instruct, chat, validate, query, transform + +import pytest + +from mellea import chat, generative, instruct, query, start_session, transform, validate +from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, META_LLAMA_3_2_1B from mellea.stdlib.base 
import ModelOutputThunk -from mellea.stdlib.session import get_session, MelleaSession -from mellea.stdlib.mify import mify, MifiedProtocol +from mellea.stdlib.mify import MifiedProtocol, mify from mellea.stdlib.requirement import req +from mellea.stdlib.session import MelleaSession, get_session + + +@pytest.fixture(scope="module") +def model_id(gh_run: int): + if gh_run == 1: + return META_LLAMA_3_2_1B + else: + return IBM_GRANITE_3_3_8B @generative @@ -26,9 +37,9 @@ def get_info(self) -> str: return f"{self.name} is {self.age} years old" -def test_basic_contextual_session(): +def test_basic_contextual_session(model_id): """Test basic contextual session usage with convenience functions.""" - with start_session(): + with start_session(model_id=model_id): # Test instruct result = instruct("Say hello") assert isinstance(result, ModelOutputThunk) @@ -51,9 +62,9 @@ def test_no_active_session_error(): chat("test") -def test_generative_with_contextual_session(): +def test_generative_with_contextual_session(model_id): """Test generative slots work with contextual sessions.""" - with start_session(): + with start_session(model_id=model_id): # Test without explicit session parameter result = classify_sentiment(text="I love this!") assert result in ["positive", "negative"] @@ -63,18 +74,18 @@ def test_generative_with_contextual_session(): assert isinstance(summary, str) assert len(summary) > 0 - -def test_generative_backward_compatibility(): +@pytest.mark.qualitative +def test_generative_backward_compatibility(model_id): """Test that generative slots still work with explicit session parameter.""" - with start_session() as m: + with start_session(model_id=model_id) as m: # Test old pattern still works result = classify_sentiment(m, text="I love this!") assert result in ["positive", "negative"] -def test_mify_with_contextual_session(): +def test_mify_with_contextual_session(model_id): """Test mify functionality with contextual sessions.""" - with start_session(): + with 
start_session(model_id=model_id): person = TestPerson("Alice", 30) assert isinstance(person, MifiedProtocol) @@ -88,13 +99,13 @@ def test_mify_with_contextual_session(): assert transform_result is not None -def test_nested_sessions(): +def test_nested_sessions(model_id): """Test nested sessions behavior.""" - with start_session() as outer_session: + with start_session(model_id=model_id) as outer_session: outer_result = instruct("outer session test") assert isinstance(outer_result, ModelOutputThunk) - with start_session() as inner_session: + with start_session(model_id=model_id) as inner_session: # Inner session should be active current_session = get_session() assert current_session is inner_session @@ -107,10 +118,10 @@ def test_nested_sessions(): assert current_session is outer_session -def test_session_cleanup(): +def test_session_cleanup(model_id): """Test session cleanup after context exit.""" session_ref = None - with start_session() as m: + with start_session(model_id=model_id) as m: session_ref = m instruct("test during session") @@ -119,19 +130,19 @@ def test_session_cleanup(): get_session() # Session should have been cleaned up - assert hasattr(session_ref, 'ctx') + assert hasattr(session_ref, "ctx") -def test_all_convenience_functions(): +def test_all_convenience_functions(model_id): """Test all convenience functions work within contextual session.""" - with start_session(): + with start_session(model_id=model_id): # Test instruct instruct_result = instruct("Generate a greeting") assert isinstance(instruct_result, ModelOutputThunk) # Test chat chat_result = chat("Hello there") - assert hasattr(chat_result, 'content') + assert hasattr(chat_result, "content") # Test validate validation = validate([req("The response should be positive")]) @@ -147,18 +158,18 @@ def test_all_convenience_functions(): assert transform_result is not None -def test_session_with_parameters(): +def test_session_with_parameters(model_id): """Test contextual session with custom 
parameters.""" - with start_session(backend_name="ollama", model_id="granite3.3:8b") as m: + with start_session(backend_name="ollama", model_id=model_id) as m: result = instruct("test with parameters") assert isinstance(result, ModelOutputThunk) assert isinstance(m, MelleaSession) -def test_multiple_sequential_sessions(): +def test_multiple_sequential_sessions(model_id): """Test multiple sequential contextual sessions.""" # First session - with start_session(): + with start_session(model_id=model_id): result1 = instruct("first session") assert isinstance(result1, ModelOutputThunk) @@ -167,14 +178,14 @@ def test_multiple_sequential_sessions(): get_session() # Second session - with start_session(): + with start_session(model_id=model_id): result2 = instruct("second session") assert isinstance(result2, ModelOutputThunk) -def test_contextual_session_with_mified_object_methods(): +def test_contextual_session_with_mified_object_methods(model_id): """Test that mified objects work properly within contextual sessions.""" - with start_session(): + with start_session(model_id=model_id): person = TestPerson("Bob", 25) # Test that mified object methods work @@ -187,12 +198,12 @@ def test_contextual_session_with_mified_object_methods(): # Test format_for_llm llm_format = person.format_for_llm() assert llm_format is not None - assert hasattr(llm_format, 'args') + assert hasattr(llm_format, "args") -def test_session_methods_with_mified_objects(): +def test_session_methods_with_mified_objects(model_id): """Test using session query/transform methods with mified objects.""" - with start_session() as m: + with start_session(model_id=model_id) as m: person = TestPerson("Charlie", 35) # Test session query method @@ -205,11 +216,11 @@ def test_session_methods_with_mified_objects(): assert transform_result is not None # Verify mified objects have query/transform object creation methods - assert hasattr(person, 'get_query_object') - assert hasattr(person, 'get_transform_object') - assert 
hasattr(person, '_query_type') - assert hasattr(person, '_transform_type') + assert hasattr(person, "get_query_object") + assert hasattr(person, "get_transform_object") + assert hasattr(person, "_query_type") + assert hasattr(person, "_transform_type") if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index 8aa51828..f47b7577 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -33,7 +33,7 @@ def test_func(session): write_email_component = write_me_an_email(session) assert isinstance(write_email_component, str) - +@pytest.mark.qualitative def test_sentiment_output(classify_sentiment_output): assert classify_sentiment_output in ["positive", "negative"] diff --git a/test/stdlib_basics/test_richdocument.py b/test/stdlib_basics/test_richdocument.py index 8809b1a0..6046ea96 100644 --- a/test/stdlib_basics/test_richdocument.py +++ b/test/stdlib_basics/test_richdocument.py @@ -93,6 +93,7 @@ def test_empty_table(): assert table is None, "table should be empty when supplied string is empty" +@pytest.mark.skip # Test requires too much memory for smaller machines. 
def test_richdocument_generation(rd: RichDocument): m = mellea.start_session(backend_name="hf") response = m.chat(rd.to_markdown()[:500] + "\nSummarize the provided document.") diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 351c08d8..9caa8d6f 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -1,22 +1,36 @@ +import os + import pytest + from mellea.stdlib.base import ModelOutputThunk from mellea.stdlib.session import start_session -def test_start_session_watsonx(): - m = start_session(backend_name="watsonx") - response = m.instruct("testing") - assert isinstance(response, ModelOutputThunk) - assert response.value is not None +def test_start_session_watsonx(gh_run): + if gh_run == 1: + pytest.skip("Skipping watsonx tests.") + else: + m = start_session(backend_name="watsonx") + response = m.instruct("testing") + assert isinstance(response, ModelOutputThunk) + assert response.value is not None -def test_start_session_openai_with_kwargs(): - m = start_session( +def test_start_session_openai_with_kwargs(gh_run): + if gh_run == 1: + m = start_session( "openai", - model_id="granite3.3:8b", - base_url="http://localhost:11434/v1", + model_id="llama3.2:1b", + base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", api_key="ollama", ) + else: + m = start_session( + "openai", + model_id="granite3.3:8b", + base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", + api_key="ollama", + ) response = m.instruct("testing") assert isinstance(response, ModelOutputThunk) assert response.value is not None