Skip to content

Commit 6054fc2

Browse files
authored
Merge pull request #21 from microsoft/experimental
Experimental
2 parents f6715a2 + 4162782 commit 6054fc2

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

README.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,83 @@ for i in range(epoch):
135135

136136
Then, we can use the familiar PyTorch-like syntax to conduct the optimization.
137137

138+
Here is another example of a simple sales agent:
139+
140+
```python
141+
from opto import trace
142+
143+
@trace.model
class Agent:
    """Sales-assistant example that greets the user in their own language.

    Each query triggers two LLM calls: one steered by ``instruct1`` whose reply
    is passed to ``decide_lang``, and one steered by ``instruct2`` whose reply
    is used as the user's name in ``greet``.

    NOTE(review): the ``trainable=True`` bundled methods below are left
    exactly as written. Their naive bodies (``decide_lang`` returns ``None``,
    ``greet`` always says "Hola") look like intentional starting points for
    the optimizer, and their source/docstrings are presumably read by it --
    confirm against the Trace documentation before editing them.
    """

    def __init__(self, system_prompt):
        # Fixed prompt shared by both LLM calls.
        self.system_prompt = system_prompt
        # Instructions wrapped as trainable nodes.
        self.instruct1 = trace.node("Decide the language", trainable=True)
        self.instruct2 = trace.node("Extract name if it's there", trainable=True)

    def __call__(self, user_query):
        """Return a greeting for ``user_query``."""
        # Step 1: ask the LLM about the language, then map its reply to a code.
        language_reply = trace.operators.call_llm(
            self.system_prompt, self.instruct1, user_query
        )
        language = self.decide_lang(language_reply)

        # Step 2: ask the LLM for the user's name and build the greeting.
        name = trace.operators.call_llm(
            self.system_prompt, self.instruct2, user_query
        )
        return self.greet(language, name)

    @trace.bundle(trainable=True)
    def decide_lang(self, response):
        """Map the language into a variable"""
        return

    @trace.bundle(trainable=True)
    def greet(self, lang, user_name):
        """Produce a greeting based on the language"""
        greeting = "Hola"
        return f"{greeting}, {user_name}!"
172+
```
173+
174+
Imagine we have a feedback function (like a reward function) that tells us how well the agent is doing. We can then optimize this agent online:
175+
176+
```python
177+
from opto.optimizers import OptoPrime
178+
179+
def feedback_fn(generated_response, gold_label='en'):
    """Grade a generated greeting against the expected language.

    Args:
        generated_response: Text produced by the agent.
        gold_label: Expected language code; ``'en'`` and ``'es'`` are the only
            labels that can be graded as correct.

    Returns:
        ``"Correct"`` if the greeting word for ``gold_label`` ("Hello" for
        ``'en'``, "Hola" for ``'es'``) occurs in ``generated_response``,
        otherwise ``"Incorrect"``.
    """
    expected_words = (('en', 'Hello'), ('es', 'Hola'))
    for language, word in expected_words:
        # The label is compared first so the substring test only runs for
        # the matching language.
        if gold_label == language and word in generated_response:
            return "Correct"
    return "Incorrect"
186+
187+
def train():
    """Optimize the example ``Agent`` online on a single Spanish query.

    Runs up to three epochs. In each epoch the agent answers
    "Hola, soy Juan.", the answer is graded by ``feedback_fn`` against the
    ``'es'`` label, and the feedback is propagated through the optimizer.
    The loop stops early once the feedback is ``'Correct'``.

    Returns:
        The agent after the optimization loop has finished.
    """
    epoch = 3
    agent = Agent("You are a sales assistant.")
    optimizer = OptoPrime(agent.parameters())

    for i in range(epoch):
        print(f"Training Epoch {i}")
        try:
            greeting = agent("Hola, soy Juan.")
        except trace.ExecutionError as e:
            # The agent's code failed while running: optimize against the
            # node carried by the exception and use its data as feedback.
            greeting = e.exception_node
            feedback = greeting.data
        else:
            feedback = feedback_fn(greeting.data, 'es')

        optimizer.zero_feedback()
        optimizer.backward(greeting, feedback)
        optimizer.step(verbose=True)

        if feedback == 'Correct':
            break

    return agent
209+
210+
# Run the optimization loop and keep the agent it returns.
agent = train()
211+
```
212+
213+
Defining and training an agent through Trace will give you more flexibility and control over what the agent learns.
214+
138215
## Tutorials
139216

140217
| **Level** | **Tutorial** | **Run in Colab** | **Description** |

examples/greeting.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from opto import trace
2+
from opto.trace import node, bundle, model, ExecutionError
3+
from opto.optimizers import OptoPrime
4+
5+
6+
@trace.model
class Agent:
    """Untrained greeting agent used as the optimization target.

    Each query triggers two LLM calls: one steered by ``instruct1`` whose reply
    is passed to ``decide_lang``, and one steered by ``instruct2`` whose reply
    is used as the user's name in ``greet``.

    NOTE(review): the ``trainable=True`` bundled methods below are left
    exactly as written. Their naive bodies (``decide_lang`` returns ``None``,
    ``greet`` always says "Hola") look like intentional starting points for
    the optimizer, and their source/docstrings are presumably read by it --
    confirm against the Trace documentation before editing them.
    """

    def __init__(self, system_prompt):
        # Fixed prompt shared by both LLM calls.
        self.system_prompt = system_prompt
        # Instructions wrapped as trainable nodes.
        self.instruct1 = trace.node("Decide the language", trainable=True)
        self.instruct2 = trace.node("Extract name if it's there", trainable=True)

    def __call__(self, user_query):
        """Return a greeting for ``user_query``."""
        # Step 1: ask the LLM about the language, then map its reply to a code.
        language_reply = trace.operators.call_llm(
            self.system_prompt, self.instruct1, user_query
        )
        language = self.decide_lang(language_reply)

        # Step 2: ask the LLM for the user's name and build the greeting.
        name = trace.operators.call_llm(
            self.system_prompt, self.instruct2, user_query
        )
        return self.greet(language, name)

    @trace.bundle(trainable=True)
    def decide_lang(self, response):
        """Map the language into a variable"""
        return

    @trace.bundle(trainable=True)
    def greet(self, lang, user_name):
        """Produce a greeting based on the language"""
        greeting = "Hola"
        return f"{greeting}, {user_name}!"
35+
36+
37+
def feedback_fn(generated_response, gold_label='en'):
    """Grade a generated greeting against the expected language.

    Args:
        generated_response: Text produced by the agent.
        gold_label: Expected language code; ``'en'`` and ``'es'`` are the only
            labels that can be graded as correct.

    Returns:
        ``"Correct"`` if the greeting word for ``gold_label`` ("Hello" for
        ``'en'``, "Hola" for ``'es'``) occurs in ``generated_response``,
        otherwise ``"Incorrect"``.
    """
    expected_words = (('en', 'Hello'), ('es', 'Hola'))
    for language, word in expected_words:
        # The label is compared first so the substring test only runs for
        # the matching language.
        if gold_label == language and word in generated_response:
            return "Correct"
    return "Incorrect"
44+
45+
46+
def train():
    """Optimize the example ``Agent`` online on a single Spanish query.

    Runs up to three epochs. In each epoch the agent answers
    "Hola, soy Juan.", the answer is graded by ``feedback_fn`` against the
    ``'es'`` label, and the feedback is propagated through the optimizer.
    The loop stops early once the feedback is ``'Correct'``.

    Returns:
        The agent after the optimization loop has finished.
    """
    epoch = 3
    agent = Agent("You are a sales assistant.")
    optimizer = OptoPrime(agent.parameters())

    for i in range(epoch):
        print(f"Training Epoch {i}")
        try:
            greeting = agent("Hola, soy Juan.")
        except ExecutionError as e:
            # The agent's code failed while running: optimize against the
            # node carried by the exception and use its data as feedback.
            greeting = e.exception_node
            feedback = greeting.data
        else:
            feedback = feedback_fn(greeting.data, 'es')

        optimizer.zero_feedback()
        optimizer.backward(greeting, feedback)
        optimizer.step(verbose=True)

        if feedback == 'Correct':
            break

    return agent
68+
69+
70+
class CorrectAgent:
    """Hand-written reference version of the greeting agent.

    Has the same structure as ``Agent`` (two LLM calls followed by
    ``decide_lang`` and ``greet``), but the two bundled methods contain
    working logic instead of placeholders.

    NOTE(review): unlike ``Agent``, this class is not decorated with
    ``@model``; confirm whether that is intentional.
    """

    def __init__(self, system_prompt):
        # Fixed prompt shared by both LLM calls.
        self.system_prompt = system_prompt
        # Instructions wrapped as trainable nodes.
        self.instruct1 = node("Decide the language: es or en?", trainable=True)
        self.instruct2 = node("Extract name if it's there", trainable=True)

    def __call__(self, user_query):
        """Return a greeting for ``user_query``."""
        response = trace.operators.call_llm(self.system_prompt, self.instruct1, user_query)
        en_or_es = self.decide_lang(response)

        user_name = trace.operators.call_llm(self.system_prompt, self.instruct2, user_query)
        greeting = self.greet(en_or_es, user_name)

        return greeting

    # Both substrings must be tested against the text explicitly: the earlier
    # form `'es' or 'spanish' in ...` parsed as `'es' or (...)`, which is
    # always truthy, so every response was classified as Spanish.
    # This is still a plain substring match, so a longer reply that merely
    # contains the letters "es" is also classified as Spanish.
    @bundle(trainable=True)
    def decide_lang(self, response):
        """Map the language into a variable"""
        text = response.lower()
        return 'es' if 'es' in text or 'spanish' in text else 'en'

    @bundle(trainable=True)
    def greet(self, lang, user_name):
        """Produce a greeting based on the language"""
        greeting = "Hola" if lang.lower() == "es" else "Hello"
        return f"{greeting}, {user_name}!"

opto/trace/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from opto.trace.containers import NodeContainer
44
from opto.trace.broadcast import apply_op
55
import opto.trace.propagators as propagators
6+
import opto.trace.operators as operators
67

78
from opto.trace.nodes import Node, GRAPH
89
from opto.trace.nodes import node

0 commit comments

Comments
 (0)