Commit ea7203c

avinash2692 authored and GitHub Enterprise committed

minor changes to chapter 6 alora tutorial

* minor changes to chapter 6 alora tutorial
* pinning trl to avoid error in importing DataCollatorForCompletionOnlyLM

1 parent f6bf855 commit ea7203c

5 files changed (+39 −31 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ from mellea.stdlib.sampling import RejectionSamplingStrategy
 # create a session with Mistral running on Ollama
 m = MelleaSession(
     backend=OllamaModelBackend(
-        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7b,
+        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
         model_options={ModelOption.MAX_NEW_TOKENS: 300},
     )
 )
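For reference, the fixed snippet as a self-contained script. Only `RejectionSamplingStrategy`'s module path is visible in the hunk header; the other import paths below are assumptions inferred from the identifiers in the diff, not confirmed by this commit.

```python
# A minimal sketch of the corrected README snippet. Import paths marked
# "assumed" are inferred from identifiers in the diff and may differ
# from the repository's actual module layout.
from mellea import MelleaSession  # assumed
from mellea.backends import model_ids  # assumed
from mellea.backends.ollama import OllamaModelBackend  # assumed
from mellea.backends.types import ModelOption  # assumed

# create a session with Mistral running on Ollama
m = MelleaSession(
    backend=OllamaModelBackend(
        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,  # uppercase 7B is the fix
        model_options={ModelOption.MAX_NEW_TOKENS: 300},
    )
)
```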

docs/examples/aLora/101_example.py

Lines changed: 21 additions & 17 deletions

@@ -1,38 +1,40 @@
 import time
 
 from mellea import LinearContext, MelleaSession
-from mellea.backends.aloras.huggingface.granite_aloras import (
-    HFConstraintAlora,
-    add_granite_aloras,
-)
+from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora
 from mellea.backends.cache import SimpleLRUCache
 from mellea.backends.huggingface import LocalHFBackend
 from mellea.stdlib.base import GenerateLog
-from mellea.stdlib.requirement import Requirement, req
+from mellea.stdlib.requirement import ALoraRequirement, Requirement
 
 # Define a backend and add the constraint aLora
 backend = LocalHFBackend(
     model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5)
 )
 
-backend.add_alora(
-    HFConstraintAlora(
-        name="custom_construant",
-        path_or_model_id="my_uploaded_model/goes_here",  # can also be the checkpoint path
-        generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
-        backend=backend,
-    )
+custom_stembolt_failure_constraint = HFConstraintAlora(
+    name="custom_stembolt_failure_constraint",
+    path_or_model_id="docs/examples/aLora/checkpoints/alora_adapter",  # can also be the checkpoint path
+    generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
+    backend=backend,
 )
 
+backend.add_alora(custom_stembolt_failure_constraint)
+
 # Create M session
 m = MelleaSession(backend, ctx=LinearContext())
 
 # define a requirement
-failure_check = req("The failure mode should not be none.")
+failure_check = ALoraRequirement(
+    "The failure mode should not be none.", alora=custom_stembolt_failure_constraint
+)
 
 # run instruction with requirement attached on the base model
 res = m.instruct(
-    "Write triage summaries based on technician note.", requirements=[failure_check]
+    """Write triage summaries based on technician note.
+    1. Oil seepage around piston rings suggests seal degradation
+    """,
+    requirements=[failure_check],
 )
 
 print("==== Generation =====")

@@ -77,9 +79,11 @@ def validate_reqs(reqs: list[Requirement]):
 
 
 # run with aLora -- which is the default if the constraint alora is added to a model
-validate_reqs([failure_check])
+computetime_alora, alora_result = validate_reqs([failure_check])
 
+# NOTE: This is not meant for use in regular programming using mellea, but just as an illustration for the speedup you can get with aloras.
 # force to run without alora
 backend.default_to_constraint_checking_alora = False
-validate_reqs([failure_check])
-backend.default_to_constraint_checking_alora = True
+computetime_no_alora, no_alora_result = validate_reqs([failure_check])
+
+print(f"Speed up time with using aloras is {computetime_alora - computetime_no_alora}")

docs/tutorial.md

Lines changed: 5 additions & 1 deletion

@@ -747,9 +747,13 @@ Mellea provides a command-line interface for training [LoRA](https://arxiv.org/a
 
 We will train a lightweight adapter with the `m alora train` command on this small dataset:
 
+> [!NOTE]
+> This script will require access to a gpu to run. You could also run this on your cpu, but it might take a while.
+> For mac users, you might not be able to run this script as is, given the lack of `fp16` support in the accelerate library.
+
 ```bash
 m alora train /to/stembolts_data.jsonl \
-  --promtfile ./prompt_config.json \
+  --promptfile ./prompt_config.json \
   --basemodel ibm-granite/granite-3.2-8b-instruct \
   --outfile ./checkpoints/alora_adapter \
   --adapter alora \
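The checkpoint written to `--outfile` is exactly what the updated `101_example.py` above consumes. A short recap wiring the two together, with every identifier taken from that example:

```python
from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora
from mellea.backends.cache import SimpleLRUCache
from mellea.backends.huggingface import LocalHFBackend

backend = LocalHFBackend(
    model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5)
)

# Point the constraint aLora at the checkpoint produced by `m alora train`.
constraint = HFConstraintAlora(
    name="custom_stembolt_failure_constraint",
    path_or_model_id="./checkpoints/alora_adapter",  # the --outfile path above
    generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>",
    backend=backend,
)
backend.add_alora(constraint)
```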

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ dependencies = [
     "typer",
     "click<8.2.0", # Newer versions will cause errors with --help in typer CLIs.
     "mistletoe>=1.4.0",
-    "trl",
+    "trl==0.19.0",
     "peft",
     "torch"
 ]
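Per the commit message, the pin avoids an import error: newer trl releases removed `DataCollatorForCompletionOnlyLM`. A one-line sanity check that the pinned version still exposes it:

```python
# Succeeds on trl==0.19.0; later trl releases drop this collator,
# which is the breakage the pin guards against.
from trl import DataCollatorForCompletionOnlyLM

print(DataCollatorForCompletionOnlyLM.__name__)
```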

uv.lock

Lines changed: 11 additions & 11 deletions
Some generated files are not rendered by default.
