|
1 | 1 | import time |
2 | 2 |
|
3 | 3 | from mellea import LinearContext, MelleaSession |
4 | | -from mellea.backends.aloras.huggingface.granite_aloras import ( |
5 | | - HFConstraintAlora, |
6 | | - add_granite_aloras, |
7 | | -) |
| 4 | +from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora |
8 | 5 | from mellea.backends.cache import SimpleLRUCache |
9 | 6 | from mellea.backends.huggingface import LocalHFBackend |
10 | 7 | from mellea.stdlib.base import GenerateLog |
11 | | -from mellea.stdlib.requirement import Requirement, req |
| 8 | +from mellea.stdlib.requirement import ALoraRequirement, Requirement |
12 | 9 |
|
13 | 10 | # Define a backend and add the constraint aLora |
14 | 11 | backend = LocalHFBackend( |
15 | 12 | model_id="ibm-granite/granite-3.2-8b-instruct", cache=SimpleLRUCache(5) |
16 | 13 | ) |
17 | 14 |
|
18 | | -backend.add_alora( |
19 | | - HFConstraintAlora( |
20 | | - name="custom_construant", |
21 | | - path_or_model_id="my_uploaded_model/goes_here", # can also be the checkpoint path |
22 | | - generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>", |
23 | | - backend=backend, |
24 | | - ) |
| 15 | +custom_stembolt_failure_constraint = HFConstraintAlora( |
| 16 | + name="custom_stembolt_failure_constraint", |
| 17 | + path_or_model_id="docs/examples/aLora/checkpoints/alora_adapter", # can also be the checkpoint path |
| 18 | + generation_prompt="<|start_of_role|>check_requirement<|end_of_role|>", |
| 19 | + backend=backend, |
25 | 20 | ) |
26 | 21 |
|
| 22 | +backend.add_alora(custom_stembolt_failure_constraint) |
| 23 | + |
27 | 24 | # Create M session |
28 | 25 | m = MelleaSession(backend, ctx=LinearContext()) |
29 | 26 |
|
30 | 27 | # define a requirement |
31 | | -failure_check = req("The failure mode should not be none.") |
| 28 | +failure_check = ALoraRequirement( |
| 29 | + "The failure mode should not be none.", alora=custom_stembolt_failure_constraint |
| 30 | +) |
32 | 31 |
|
33 | 32 | # run instruction with requirement attached on the base model |
34 | 33 | res = m.instruct( |
35 | | - "Write triage summaries based on technician note.", requirements=[failure_check] |
| 34 | + """Write triage summaries based on technician note. |
| 35 | + 1. Oil seepage around piston rings suggests seal degradation |
| 36 | + """, |
| 37 | + requirements=[failure_check], |
36 | 38 | ) |
37 | 39 |
|
38 | 40 | print("==== Generation =====") |
@@ -77,9 +79,11 @@ def validate_reqs(reqs: list[Requirement]): |
77 | 79 |
|
78 | 80 |
|
79 | 81 | # run with aLora -- which is the default if the constraint alora is added to a model |
80 | | -validate_reqs([failure_check]) |
| 82 | +computetime_alora, alora_result = validate_reqs([failure_check]) |
81 | 83 |
|
| 84 | +# NOTE: This is not intended for regular mellea programs; it only illustrates the speedup you can get with aLoras. |
82 | 85 | # force to run without alora |
83 | 86 | backend.default_to_constraint_checking_alora = False |
84 | | -validate_reqs([failure_check]) |
85 | | -backend.default_to_constraint_checking_alora = True |
| 87 | +computetime_no_alora, no_alora_result = validate_reqs([failure_check]) |
| 88 | + |
| 89 | +print(f"Speed up time with using aloras is {computetime_alora - computetime_no_alora}") |
0 commit comments