7 changes: 6 additions & 1 deletion docs/examples/best_of_n/prm.py
@@ -2,6 +2,7 @@

from docs.examples.helper import w
from mellea import start_session
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.process_reward_models.huggingface.prms import (
HFGenerativePRM,
HFRegressionPRM,
@@ -11,7 +12,11 @@
from mellea.stdlib.sampling.best_of_n import BestofNSamplingStrategy

# create a session for the generator using Granite 3.3 8B on Huggingface and a simple context [see below]
m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 512})
m = start_session(
backend_name="hf",
model_id=IBM_GRANITE_3_3_8B,
model_options={ModelOption.MAX_NEW_TOKENS: 512},
)

# initialize the PRM model
prm_model = HFGenerativePRM(
@@ -130,11 +130,11 @@ def extract_final_short_answer(

if __name__ == "__main__":
scores = []
m = start_session()

for question, target in (
x.values() for x in load_dataset("gsm8k", "main", split="train[:100]")
):
m = start_session()

target = int(target.split("####")[-1])
response = compute_chain_of_thought_and_final_answer(m, question=question)
for step in response.step_by_step_solution:
4 changes: 2 additions & 2 deletions docs/examples/generative_slots/generative_slots.py
@@ -15,8 +15,8 @@ def generate_summary(text: str) -> str:


if __name__ == "__main__":
with start_session():
sentiment_component = classify_sentiment(text="I love this!")
with start_session() as m:
sentiment_component = classify_sentiment(m, text="I love this!")
print("Output sentiment is : ", sentiment_component)

summary = generate_summary(
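A note for readers of this hunk: the change binds the session with `as m` and passes it to the generative slot explicitly instead of relying on an ambient session. A minimal sketch of the full pattern, assuming `@generative` and `start_session` behave as shown elsewhere in this diff:

from mellea import generative, start_session


@generative
def classify_sentiment(text: str) -> str:
    """Classify the sentiment of the text as 'positive' or 'negative'."""


with start_session() as m:
    # the session is now an explicit first argument to the generative slot
    sentiment = classify_sentiment(m, text="I love this!")
    print("Output sentiment is : ", sentiment)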
6 changes: 4 additions & 2 deletions docs/examples/information_extraction/101_with_gen_slots.py
@@ -3,12 +3,14 @@
from mellea import generative, start_session
from mellea.backends import model_ids

m = start_session(model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B)
m = start_session()


@generative
def extract_all_person_names(doc: str) -> list[str]:
"""Given a document, extract all person names. Return these names as list of strings."""
"""
Given a document, extract names of ALL mentioned persons. Return these names as list of strings.
"""


# ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html
@@ -35,11 +35,11 @@ def _at_least_(t: str) -> bool:


# start session
m = start_session(model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B)
m = start_session()

# run extraction using grounding context and sampling strategy
sampled_p_names = m.instruct(
"Extract the person names from the document (doc1).",
"Extract ALL person names from the document (doc1).",
grounding_context={"doc1": NYTimes_text},
requirements=[check(None, validation_fn=simple_validate(at_least_n(2)))],
strategy=RejectionSamplingStrategy(loop_budget=5),
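The `at_least_n(2)` validator handed to `simple_validate` above is only partly visible in this hunk (the nested `_at_least_(t: str) -> bool` signature appears at the top). A hedged sketch of its likely shape; the body is an assumption for illustration, not the repository's implementation:

def at_least_n(n: int):
    # hypothetical reconstruction: a factory returning a str -> bool predicate
    def _at_least_(t: str) -> bool:
        # e.g. require at least n non-empty comma-separated items in the output
        return len([s for s in t.split(",") if s.strip()]) >= n

    return _at_least_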
5 changes: 3 additions & 2 deletions docs/examples/instruct_validate_repair/101_email.py
@@ -17,10 +17,11 @@
# # start_session() is equivalent to:
# from mellea.backends import model_ids
# from mellea.backends.ollama import OllamaModelBackend
# from mellea import MelleaSession, SimpleContext
# from mellea import MelleaSession
# from mellea.stdlib.base import SimpleContext
# m = MelleaSession(
# backend=OllamaModelBackend(
# model_id=model_ids.IBM_GRANITE_3_3_8B,
# model_id=model_ids.IBM_GRANITE_4_MICRO_3B,
# model_options={ModelOption.MAX_NEW_TOKENS: 200},
# ),
# ctx=SimpleContext()
@@ -11,7 +11,7 @@
"Write a very funny email to invite all interns to the office party."
)
print(
f"***** email 1 ****\n{w(email_v1)}\n*******email 2 ******\n{w(email_v2)}\n*******"
f"***** email 1 ****\n{w(email_v1)}\n*******email 2 ******\n{w(email_v2)}\n\n*******"
)

# Use the emails as grounding context to evaluate which one is quirkier
@@ -8,7 +8,7 @@
# write an email with automatic requirement checking.
email_v1 = m.instruct(
"Write an email to invite all interns to the office party.",
requirements=["be formal", "Use 'Dear interns' as greeting."],
requirements=["be formal", "Use 'Dear Interns' as greeting."],
)

# print result
@@ -1,14 +1,15 @@
from docs.examples.helper import req_print, w
from mellea import start_session
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.types import ModelOption
from mellea.stdlib.sampling import RejectionSamplingStrategy

# create a session using Granite 3.3 8B on Ollama and a simple context [see below]
# create a session using Granite 4 Micro (3B) on Ollama and a simple context [see below]
m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200})

email_v2_samples = m.instruct(
"Write an email to invite all interns to the office party.",
requirements=["be formal", "Use 'Dear interns' as greeting."],
"Write a very short email to invite all interns to the office party.",
requirements=["Use formal language.", "Use 'Dear Interns' as greeting."],
strategy=RejectionSamplingStrategy(loop_budget=3),
return_sampling_results=True,
)
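Because `return_sampling_results=True` is kept, the caller receives a sampling result rather than a plain string. A short sketch of consuming it, assuming it exposes `.result` and `.sample_generations` the way the `docs/examples/tutorial/simple_email.py` change later in this diff does:

best = email_v2_samples.result                    # the sample that satisfied the requirements
candidates = email_v2_samples.sample_generations  # every candidate generated within the loop budget
print(str(best))
print(f"{len(candidates)} candidate(s) generated in total.")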
2 changes: 1 addition & 1 deletion docs/examples/mini_researcher/researcher.py
@@ -20,7 +20,7 @@
@cache
def get_session():
"""Get M session (change model here)."""
return MelleaSession(backend=OllamaModelBackend(model_ids.IBM_GRANITE_3_3_8B))
return MelleaSession(backend=OllamaModelBackend(model_ids.IBM_GRANITE_4_MICRO_3B))


@cache
4 changes: 3 additions & 1 deletion docs/examples/mobject/table.py
@@ -3,6 +3,7 @@
import pandas

import mellea
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.stdlib.mify import mify


@@ -51,9 +52,10 @@ def transpose(self):
if __name__ == "__main__":
m = mellea.start_session()
db = MyCompanyDatabase()
print(m.query(db, "What were sales for the Northeast branch this month?"))
print(m.query(db, "What were sales for the Northeast branch this month?").value)
result = m.transform(db, "Update the northeast sales to 1250.")
print(type(result))
print(db.table)
print(m.query(db, "What were sales for the Northeast branch this month?"))
result = m.transform(db, "Transpose the table.")
print(result)
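The `.value` added to the first `m.query(...)` call suggests `query` returns a response object rather than a plain string; a hedged sketch of reading it (the wrapper type itself is not shown in this diff):

resp = m.query(db, "What were sales for the Northeast branch this month?")
print(type(resp))   # a response wrapper, not str
print(resp.value)   # the text answer, as the updated example prints it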
@@ -7,7 +7,7 @@
},
"source": [
"# Compositionality with Generative Slots\n",
"This Jupyter notebook runs on Colab demonstrates GEnerative Slots, a function whose implementation is provided by an LLM. "
"This Jupyter notebook runs on Colab demonstrates GEnerative Slots, a function whose implementation is provided by an LLM."
]
},
{
@@ -168,7 +168,7 @@
"metadata": {},
"source": [
"## Start a Mellea Session\n",
"We initialize a backend running Ollama using the granite3.3-chat model."
"We initialize a backend running Ollama using the granite 4 micro model."
]
},
{
2 changes: 1 addition & 1 deletion docs/examples/notebooks/context_example.ipynb
@@ -73,7 +73,7 @@
"\n",
"Up to this point we have used SimpleContext, a context manager that resets the chat message history on each model call. That is, the model's context is entirely determined by the current Component.\n",
"\n",
"Mellea also provides a LinearContext, which behaves like a chat history. We will use the LinearContext to interact with cat hmodels:"
"Mellea also provides a LinearContext, which behaves like a chat history. We will use the ChatContext to interact with chat models:"
]
},
{
2 changes: 1 addition & 1 deletion docs/examples/notebooks/example.ipynb
@@ -70,7 +70,7 @@
},
"source": [
"## Import Mellea and Start a Session\n",
"We initialize a backend running Ollama using the granite3.3-chat model."
"We initialize a backend running Ollama using the granite 4 micro model."
]
},
{
6 changes: 3 additions & 3 deletions docs/examples/notebooks/georgia_tech.ipynb
@@ -21,7 +21,7 @@
"\n",
"Run the first cell during our introduction. The first cell will:\n",
" * download an install ollama on your Colab instance\n",
" * download the `granite3.3:8b` model weights\n"
" * download the `ibm/granite4:micro` model weights\n"
]
},
{
@@ -37,7 +37,7 @@
"!nohup ollama serve >/dev/null 2>&1 &\n",
"\n",
"# Download the granite:3.3:8b weights.\n",
"!ollama pull granite3.3:8b\n",
"!ollama pull ibm/granite4:micro\n",
"!ollama pull llama3.2:3b\n",
"\n",
"# install Mellea.\n",
@@ -68,7 +68,7 @@
"# LAB 1: Hello, Mellea!\n",
"\n",
"Running `mellea.start_session()` initialize a new `MelleaSession`. The session holds three things:\n",
"1. The model to use for this session. In this tutorial we will use granite3.3:8b.\n",
"1. The model to use for this session. In this tutorial we will use granite 4 micro (3B).\n",
"2. An inference engine; i.e., the code that actually calls our model. We will be using ollama, but you can also use Huggingface or any OpenAI-compatible endpoint.\n",
"3. A `Context`, which tells Mellea how to remember context between requests. This is sometimes called the \"Message History\" in other frameworks. Throughout this tutorial, we will be using a `SimpleContext`. In `SimpleContext`s, **every request starts with a fresh context**. There is no preserved chat history between requests. Mellea provides other types of context, but today we will not be using those features. See the Tutorial for further details."
]
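The LAB 1 text above lists the three things a session holds (model, inference engine, context). A minimal sketch of those defaults after this PR, using only calls that appear elsewhere in this diff; the comment about context behavior paraphrases the notebook text:

import mellea

m = mellea.start_session()           # ollama backend + ibm/granite4:micro by default
answer = m.instruct("What is 2x2?")  # each request starts with a fresh context
print(str(answer))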
2 changes: 1 addition & 1 deletion docs/examples/rag/simple_rag_with_filter.py
@@ -68,7 +68,7 @@ def is_answer_relevant_to_question(answer: str, question: str) -> bool:
print(f"results:\n {results_str}\n ====")
del embedding_model # help GC

# Create Mellea session
# Create Mellea session with Mistral. Other models also work.
m = start_session(model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B)

# Check for each document from retrieval if it is actually relevant
6 changes: 4 additions & 2 deletions docs/examples/tutorial/document_mobject.py
@@ -1,3 +1,5 @@
from mellea.backends import model_ids
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.stdlib.docs.richdocument import RichDocument

rd = RichDocument.from_document_file("https://arxiv.org/pdf/1906.04043")
@@ -10,11 +12,11 @@
from mellea import start_session # noqa: E402
from mellea.backends.types import ModelOption # noqa: E402

m = start_session()
m = start_session(model_id=model_ids.META_LLAMA_3_2_3B)
for seed in [x * 12 for x in range(5)]:
table2 = m.transform(
table1,
"Add a column 'Model' that extracts which model was used or 'None' if none.",
"Add a column 'Model' that extracts which model was used in Feature description or 'None' if none.",
model_options={ModelOption.SEED: seed},
)
if isinstance(table2, Table):
7 changes: 3 additions & 4 deletions docs/examples/tutorial/model_options_example.py
@@ -4,13 +4,12 @@
from mellea.backends.types import ModelOption

m = mellea.MelleaSession(
backend=OllamaModelBackend(
model_id=model_ids.IBM_GRANITE_3_2_8B, model_options={ModelOption.SEED: 42}
)
backend=OllamaModelBackend(model_options={ModelOption.SEED: 42})
)

answer = m.instruct(
"What is 2x2?", model_options={"temperature": 0.5, "num_predict": 5}
"What is 2x2?",
model_options={ModelOption.TEMPERATURE: 0.5, ModelOption.MAX_NEW_TOKENS: 15},
)

print(str(answer))
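The edit above swaps ollama-specific option strings ("temperature", "num_predict") for backend-neutral ModelOption keys. A sketch of how backend-level and per-call options now layer; that per-call options override the backend's defaults is an assumption, not something this diff shows:

import mellea
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends.types import ModelOption

m = mellea.MelleaSession(
    backend=OllamaModelBackend(model_options={ModelOption.SEED: 42})  # backend-level option
)
answer = m.instruct(
    "What is 2x2?",
    model_options={ModelOption.TEMPERATURE: 0.5, ModelOption.MAX_NEW_TOKENS: 15},  # per-call options
)
print(str(answer))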
4 changes: 2 additions & 2 deletions docs/examples/tutorial/simple_email.py
@@ -15,7 +15,7 @@ def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str:
"Write an email to {{name}} using the notes following: {{notes}}.",
user_variables={"name": name, "notes": notes},
)
return email.value # str(email) also works.
return str(email.value) # str(email) also works.


print(
@@ -70,7 +70,7 @@ def write_email_with_strategy(m: mellea.MelleaSession, name: str, notes: str) ->
return str(email_candidate.result)
else:
print("Expect sub-par result.")
return email_candidate.sample_generations[0].value
return str(email_candidate.sample_generations[0].value)


print(
2 changes: 1 addition & 1 deletion mellea/backends/litellm.py
@@ -46,7 +46,7 @@ class LiteLLMBackend(FormatterBackend):

def __init__(
self,
model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_3_3_8B.ollama_name),
model_id: str = "ollama/" + str(model_ids.IBM_GRANITE_4_MICRO_3B.ollama_name),
formatter: Formatter | None = None,
base_url: str | None = "http://localhost:11434",
model_options: dict | None = None,
8 changes: 8 additions & 0 deletions mellea/backends/model_ids.py
@@ -25,6 +25,14 @@ class ModelIdentifier:
#### IBM models ####
####################

IBM_GRANITE_4_MICRO_3B = ModelIdentifier(
hf_model_name="ibm-granite/granite-4.0-micro",
ollama_name="ibm/granite4:micro",
watsonx_name="ibm/granite-4-h-small",
)
Contributor comment on lines +31 to +32:
Do we care that if we eventually switch this we might change results / performance for watsonx backend users? I don't think so, just wanted to flag that.

# todo: watsonx model is different from ollama model - should be same.


IBM_GRANITE_3_2_8B = ModelIdentifier(
hf_model_name="ibm-granite/granite-3.2-8b-instruct",
ollama_name="granite3.2:8b",
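For orientation, a sketch of how the three per-backend names on the new ModelIdentifier are consumed: the ollama usage mirrors the LiteLLM default earlier in this diff, while the HF and watsonx usages are assumptions (see the reviewer note above about the watsonx name differing):

from mellea.backends.model_ids import IBM_GRANITE_4_MICRO_3B

mid = IBM_GRANITE_4_MICRO_3B
print(mid.ollama_name)                   # "ibm/granite4:micro"
print(mid.hf_model_name)                 # "ibm-granite/granite-4.0-micro"
print(mid.watsonx_name)                  # "ibm/granite-4-h-small"
print("ollama/" + str(mid.ollama_name))  # how the LiteLLM backend builds its default model id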
2 changes: 1 addition & 1 deletion mellea/backends/ollama.py
@@ -38,7 +38,7 @@ class OllamaModelBackend(FormatterBackend):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
2 changes: 1 addition & 1 deletion mellea/backends/openai.py
@@ -75,7 +75,7 @@ class OpenAIBackend(FormatterBackend, AloraBackendMixin):

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_4_MICRO_3B,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
8 changes: 6 additions & 2 deletions mellea/stdlib/session.py
@@ -9,7 +9,11 @@

import mellea.stdlib.funcs as mfuncs
from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, ModelIdentifier
from mellea.backends.model_ids import (
IBM_GRANITE_3_3_8B,
IBM_GRANITE_4_MICRO_3B,
ModelIdentifier,
)
from mellea.backends.ollama import OllamaModelBackend
from mellea.backends.openai import OpenAIBackend
from mellea.helpers.fancy_logger import FancyLogger
@@ -70,7 +74,7 @@ def backend_name_to_class(name: str) -> Any:

def start_session(
backend_name: Literal["ollama", "hf", "openai", "watsonx", "litellm"] = "ollama",
model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
model_id: str | ModelIdentifier = IBM_GRANITE_4_MICRO_3B,
ctx: Context | None = None,
*,
model_options: dict | None = None,
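With the start_session default flipped to Granite 4 Micro, examples that still need the older model pin it explicitly, as docs/examples/best_of_n/prm.py does at the top of this diff. A minimal sketch:

from mellea import start_session
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B

# opt back in to the previous default model for a single session
m = start_session(model_id=IBM_GRANITE_3_3_8B)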