diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..ec45bb82
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,19 @@
+.PHONY: help doc doc-deploy
+
+help:
+ @echo "Makefile for managing Jupyter Book documentation and deployment"
+ @echo ""
+ @echo "Usage:"
+ @echo " make doc - Build the documentation"
+ @echo " make doc-deploy - Deploy the documentation to GitHub Pages"
+ @echo " make help - Display this help message"
+ @echo ""
+ @echo "For more information, refer to the README or script documentation."
+
+doc:
+ @echo "Building documentation..."
+ @bash ./docs/jupyter_build.sh
+
+doc-deploy:
+ @echo "Deploying documentation to GitHub Pages..."
+ @ghp-import -n -p -f docs/_build/html
\ No newline at end of file
diff --git a/docs/_config.yml b/docs/_config.yml
index 64a8e27a..82390eee 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -4,7 +4,7 @@
title: Trace
author: Ching-An Cheng, Allen Nie, Adith Swaminathan
copyright: "2024 Trace Team"
-# logo: logo.png
+logo: trace_logo.png
only_build_toc_files: true
# Force re-execution of notebooks on each build.
@@ -45,24 +45,26 @@ sphinx:
html_context:
default_mode: light
extra_extensions:
- - sphinx_plausible
- - autodoc2
+ - 'sphinx_plausible'
- 'sphinx.ext.autodoc'
- 'sphinx.ext.napoleon'
+ - 'sphinx.ext.autosummary'
- 'sphinx.ext.viewcode'
config:
add_module_names: false
plausible_domain: microsoft.github.io/trace
nb_merge_streams: true
- autodoc2_index_template: null
- autodoc2_output_dir: api
- autodoc2_module_all_regexes:
- - "opto.*(trace|optimizers)"
- autodoc2_packages:
- - path: "../opto"
- autodoc2_hidden_objects:
- - inherited
- - private
- - dunder
- autodoc2_skip_module_regexes:
- - .*test.*
\ No newline at end of file
+ templates_path: ["_templates"]
+ autosummary_generate: True
+ autodoc_mock_imports: ['autogen']
+ suppress_warnings: ["etoc.toctree"]
+ # autodoc settings
+ # ref: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration
+ autoclass_content: both
+ autodoc_class_signature: separated
+ autodoc_member_order: groupwise
+ autodoc_docstring_signature: True
+ autodoc_typehints: signature
+ autodoc_typehints_format: short
+ autosummary_filename_map:
+ opto.trace.nodes.node: "opto.trace.nodes.node-function"
\ No newline at end of file
diff --git a/docs/_toc.yml b/docs/_toc.yml
index 6a6b1d2c..47cf5984 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -11,12 +11,6 @@ parts:
- file: quickstart/quick_start
- file: quickstart/quick_start_2
- file: quickstart/virtualhome
-
- - caption: FAQ
- numbered: false
- chapters:
- - file: faq/faq
-
- caption: 📚Tutorials
chapters:
- file: tutorials/basic_tutorial
@@ -41,6 +35,11 @@ parts:
chapters:
- file: examples/robotics/metaworld
+ - caption: FAQ
+ numbered: false
+ chapters:
+ - file: faq/faq
+
- caption: đź“– API Reference
chapters:
- - glob: api/opto/*
\ No newline at end of file
+ - file: api
\ No newline at end of file
diff --git a/docs/api.rst b/docs/api.rst
new file mode 100644
index 00000000..857476a6
--- /dev/null
+++ b/docs/api.rst
@@ -0,0 +1,6 @@
+.. autosummary::
+ :toctree: api/
+ :template: module.rst_t
+ :recursive:
+
+ opto
\ No newline at end of file
diff --git a/docs/faq/faq.md b/docs/faq/faq.md
index b5cfdb23..0c4e1926 100644
--- a/docs/faq/faq.md
+++ b/docs/faq/faq.md
@@ -1,6 +1,6 @@
# FAQ
-### Difference to Libraries like TextGrad
+## Difference to Libraries like TextGrad
TextGrad is both a library and an optimizer algorithm. Currently, we support three optimizers:
@@ -43,5 +43,5 @@ We provide a comparison to validate our implementation of TextGrad in Trace:
To produce this table, we ran the TextGrad pip-installed repo on 2024-10-30, and we also include the numbers reported in the TextGrad paper.
The LLM APIs are called around the same time to ensure a fair comparison. TextGrad paper's result was reported in 2024-06.
-### Difference to Libraries like AutoGen, AG2, OpenAI Swarm, Llama Stack
+## Difference to Libraries like AutoGen, AG2, OpenAI Swarm, Llama Stack
diff --git a/docs/intro.md b/docs/intro.md
index 326fc417..1fae750c 100644
--- a/docs/intro.md
+++ b/docs/intro.md
@@ -3,6 +3,9 @@
**Trace is a Python library for tracing and optimizing workflows end-to-end by using LLM-powered generative optimizers.**
**It can record *traces* of operations on any Python objects and functions, and automatically construct an execution graph that is useful when LLMs are used as optimizers.**
+
+
+
Our implementation is minimal and purely based on Python. It does not involve any API calls or library-specific dependencies, so it is composable with other libraries and tools.
Trace features an API design inspired by PyTorch Autograd's gradient tape mechanism, which we adopted to reduce the learning curve of using Trace.
These features make Trace an intuitive and flexible framework for building self-adapting AI agents.
@@ -25,14 +28,59 @@ This step is the **declare** phase where a user chooses how to represent the age
After the user has declared the inputs and operations, Trace captures the execution flow of the program as a graph. This step is the **forward** phase.
Finally, the user can optimize the entire program, such as by updating the LLM instructions, using Trace. This step is the **optimize** phase.
+```python
+@trace.model
+class Agent:
+
+ def __init__(self, system_prompt):
+ self.system_prompt = system_prompt
+ self.instruct1 = trace.node("Decide the language", trainable=True)
+ self.instruct2 = trace.node("Extract name", trainable=True)
+
+ def __call__(self, user_query):
+ # First LLM
+ response = call_llm(self.system_prompt, self.instruct1, user_query)
+ en_or_es = self.decide_lang(response)
+ # Second LLM
+ user_name = call_llm(self.system_prompt, self.instruct2, user_query)
+ greeting = self.greet(en_or_es, user_name)
+ return greeting
+
+ @trace.bundle(trainable=True)
+ def decide_lang(self, response):
+ """Map the language into a variable"""
+ return
+
+ @trace.bundle(trainable=True)
+ def greet(self, lang, user_name):
+ """Produce a greeting based on the language"""
+ greeting = "Hola"
+ return f"{greeting}, {user_name}!"
+```
+
Each application of Trace is defined by an **agent**, a source of **feedback**, and an **optimizer**.
Enabling traces of operations on Python objects allows us to capture the execution flow of an agent, including AI systems that involve LLMs.
In the example below, we show how Trace can optimize an entire AI system end-to-end.
+```python
+agent = Agent("You are a sales assistant.")
+optimizer = OptoPrime(agent.parameters())
+try:
+ greeting = agent("Hola, soy Juan.")
+ feedback = feedback_fn(greeting.data, 'es')
+ # feedback = "Correct" or "Incorrect"
+except ExecutionError as e:
+ greeting = e.exception_node
+ feedback = greeting.data,
+
+optimizer.zero_feedback()
+optimizer.backward(greeting, feedback)
+optimizer.step()
+```
----
+``` -->
\ No newline at end of file
diff --git a/docs/jupyter_build.sh b/docs/jupyter_build.sh
new file mode 100644
index 00000000..dd9ebf48
--- /dev/null
+++ b/docs/jupyter_build.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+cd "$(dirname "$0")/.." || exit
+rm -r docs/_build docs/api
+ORIGINAL_PYTHONPATH=$PYTHONPATH
+export PYTHONPATH=$(pwd)/..:$PYTHONPATH
+
+jupyter-book build docs
+
+# clean up sphinx-autosummary generated files
+rm -r docs/api
+
+# Restored PYTHONPATH
+export PYTHONPATH=$ORIGINAL_PYTHONPATH
+
+# move all files associated with the landing page into the `_build/html` folder
+python docs/post_build_script.py
\ No newline at end of file
diff --git a/docs/logo.png b/docs/logo.png
deleted file mode 100644
index efb3fbff..00000000
Binary files a/docs/logo.png and /dev/null differ
diff --git a/docs/quickstart/installation.md b/docs/quickstart/installation.md
index 85118d8b..1f4d1d1c 100644
--- a/docs/quickstart/installation.md
+++ b/docs/quickstart/installation.md
@@ -12,10 +12,10 @@ then we require `autogen` package to make LLM API calls.
To install Trace, run:
```{admonition} Installation Command
+
```bash
pip install trace-opt
```
-```
To contribute to the development, you can clone the repository and install the package in editable mode:
diff --git a/docs/readme.md b/docs/readme.md
index 7b4ec9cd..973cda5c 100644
--- a/docs/readme.md
+++ b/docs/readme.md
@@ -2,19 +2,16 @@ Steps of deployment:
IMPORTANT: checkout the `website` branch.
-1. Run `jupyter-book build docs` under the root directory to build the book. This will create a folder `_build/html` that has the static webpages.
-2. Run `python docs/post_build_script.py` to move all files associated with the landing page into the `_build/html` folder.
-3. Run `ghp-import -n -p -f docs/_build/html` to deploy the book to GitHub Pages (it creates a branch in the repo)
-
-Or simply run `docs/publish.sh` to run all the above commands.
+1. Run `make doc` under the root directory to build the book. This will create a folder `docs/_build/html` that has the static webpages.
+2. Run `make doc-deploy` to deploy the book to GitHub Pages (it creates a branch in the repo)
References:
-https://jupyterbook.org/en/stable/start/publish.html
+https://jupyterbook.org/en/stable/start/publish.html
A few notes:
1. There is no direct way to add an HTML page to Jupyter book.
-2. Run `pip install ghp-import` before step 3.
+2. Run `pip install -r requirements.txt` to install dependencies.
3. Do not manually modify `gh-pages` branch.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 7e821e45..e5ad1997 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,3 +1,8 @@
jupyter-book
matplotlib
numpy
+sphinx
+sphinx-plausible
+# sphinx-autodoc2
+sphinx-autoapi
+ghp-import
\ No newline at end of file
diff --git a/docs/trace_logo.png b/docs/trace_logo.png
new file mode 100644
index 00000000..ceba3460
Binary files /dev/null and b/docs/trace_logo.png differ
diff --git a/docs/tutorials/custom_optimizers.ipynb b/docs/tutorials/custom_optimizers.ipynb
index 879d6ece..4a3d835a 100644
--- a/docs/tutorials/custom_optimizers.ipynb
+++ b/docs/tutorials/custom_optimizers.ipynb
@@ -13,7 +13,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Basic back-propagation and gradient descent with PyTorch\n",
+ "## Basic back-propagation and gradient descent with PyTorch\n",
"\n",
"To start, let's define a simple objective and run vanilla gradient descent to optimize the variable in pytorch. This code will be used as the reference of desired behaviors. We make the code below transparent for tutorial purppose, so we use the `torch.autograd.grad` api and write down the gradient descent update rule manually."
]
@@ -73,7 +73,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Set up the objective in Trace\n",
+ "## Set up the objective in Trace\n",
"\n",
"After seeing how ideally basic gradient descent + back-propagation behaves, next we show how it can be implemented it in Trace. To this end, we need to turn each math ops used in the above loss as a `bundle`, and define the parameter as a `node`. In this way, Trace can create a computational graph (DAG) of the workflow of computing the objective. We visualize the DAG below."
]
@@ -216,7 +216,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Version 1 Trace Implementation based on Optimizer\n",
+ "## Version 1 Trace Implementation based on Optimizer\n",
"\n",
"The first way is to implement the back-propagation algorithm as part of the optimizer in Trace. By default, optimzers in Trace receive the propagated Trace graph at the parameter nodes. Trace graph is a generalization of gradient. Here we show how we can implement back-propagation on the Trace graph to recover the propagated gradient and use it for gradient descent. We can see the loss sequence here matches what we had above implemented by PyTorch.\n"
]
@@ -286,7 +286,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Version 2 Trace Implementation based on Propagator + Optimizer\n",
+ "## Version 2 Trace Implementation based on Propagator + Optimizer\n",
"\n",
"Another way is to override the what's propagated in the `backward` call of Trace. Trace has a generic backward routine performed on the computational graph that can support designing new end-to-end optimization algorithms. While by default Trace propagates Trace graphes in `backward` for generality, for the differentiable problems here we can override the behavior and let it directly propagate gradients. In this way, the optimizer would receive directly the propagted gradient instead of Trace graphs.\n"
]
diff --git a/docs/tutorials/minibatch.ipynb b/docs/tutorials/minibatch.ipynb
index 19847024..80fd12b9 100644
--- a/docs/tutorials/minibatch.ipynb
+++ b/docs/tutorials/minibatch.ipynb
@@ -267,7 +267,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Batching Non-Commutative Feedbacks\n",
+ "## Batching Non-Commutative Feedbacks\n",
"\n",
"In the earlier numerical example, the loss function was commutative so that we can do `batch_loss += loss(each_input)`. What if the feedbacks received are not commutative? This can happen often with non-numeric (e.g. text) feedbacks. Here we will see a simple design pattern for using `trace` and `OptoPrime` for batch optimization in such cases."
]
diff --git a/opto/optimizers/buffers.py b/opto/optimizers/buffers.py
index 02841138..b0f2c9b9 100644
--- a/opto/optimizers/buffers.py
+++ b/opto/optimizers/buffers.py
@@ -7,7 +7,7 @@ def __init__(self, size: int):
def add(self, item):
if self.size > 0:
self.buffer.append(item)
- self.buffer = self.buffer[-self.size :]
+ self.buffer = self.buffer[-self.size:]
def __iter__(self):
return iter(self.buffer)
diff --git a/opto/optimizers/opro.py b/opto/optimizers/opro.py
index 60b0b749..22ae199c 100644
--- a/opto/optimizers/opro.py
+++ b/opto/optimizers/opro.py
@@ -19,7 +19,8 @@ class OPRO(OptoPrime):
output_format_prompt = dedent(
"""
- Output_format: Your output should be in the following json format, satisfying the json syntax:
+ Output_format: Your output should be in the following json format, satisfying
+ the json syntax:
{{
"suggestion": {{
@@ -28,7 +29,10 @@ class OPRO(OptoPrime):
}}
}}
- When suggestion variables, write down the suggested values in "suggestion". When of a variable is (code), you should write the new definition in the format of python code without syntax errors, and you should not change the function name or the function signature.
+ When suggestion variables, write down the suggested values in "suggestion".
+ When of a variable is (code), you should write the new definition in the
+ format of python code without syntax errors, and you should not change the
+ function name or the function signature.
If no changes or answer are needed, just output TERMINATE.
"""
@@ -57,5 +61,7 @@ def construct_prompt(self, summary, mask=None, *args, **kwargs):
)
examples = "\n".join(examples)
- user_prompt = self.user_prompt_template.format(examples=examples, instruction=self.objective)
- return self.output_format_prompt, user_prompt
\ No newline at end of file
+ user_prompt = self.user_prompt_template.format(
+ examples=examples, instruction=self.objective
+ )
+ return self.output_format_prompt, user_prompt
diff --git a/opto/optimizers/optimizer.py b/opto/optimizers/optimizer.py
index 33e1e3ae..c97865f2 100644
--- a/opto/optimizers/optimizer.py
+++ b/opto/optimizers/optimizer.py
@@ -29,9 +29,15 @@ def propagator(self):
class Optimizer(AbstractOptimizer):
- """ Optimizer based on Trace graph. """
-
- def __init__(self, parameters: List[ParameterNode], *args, propagator: Propagator = None, **kwargs):
+ """Optimizer based on Trace graph."""
+
+ def __init__(
+ self,
+ parameters: List[ParameterNode],
+ *args,
+ propagator: Propagator = None,
+ **kwargs
+ ):
super().__init__(parameters)
propagator = propagator if propagator is not None else self.default_propagator()
assert isinstance(propagator, Propagator)
@@ -43,7 +49,7 @@ def propagator(self):
@property
def trace_graph(self):
- """ Aggregate the graphs of all the parameters. """
+ """Aggregate the graphs of all the parameters."""
return sum_feedback(self.parameters)
def step(self, *args, **kwargs):
diff --git a/opto/optimizers/optoprime.py b/opto/optimizers/optoprime.py
index f47ad2c4..26dfc0fb 100644
--- a/opto/optimizers/optoprime.py
+++ b/opto/optimizers/optoprime.py
@@ -1,4 +1,3 @@
-
from typing import Any, List, Dict, Union, Tuple
from dataclasses import dataclass, asdict
from textwrap import dedent, indent
@@ -248,7 +247,6 @@ class OptoPrime(Optimizer):
"documentation": "#Documentation",
}
-
def __init__(
self,
parameters: List[ParameterNode],
@@ -262,7 +260,7 @@ def __init__(
max_tokens=4096,
log=True,
prompt_symbols=None,
- filter_dict : Dict = None, # autogen filter_dict
+ filter_dict: Dict = None, # autogen filter_dict
**kwargs,
):
super().__init__(parameters, *args, propagator=propagator, **kwargs)
@@ -305,7 +303,11 @@ def default_propagator(self):
def summarize(self):
# Aggregate feedback from all the parameters
- feedbacks = [self.propagator.aggregate(node.feedback) for node in self.parameters if node.trainable]
+ feedbacks = [
+ self.propagator.aggregate(node.feedback)
+ for node in self.parameters
+ if node.trainable
+ ]
summary = sum(feedbacks) # TraceGraph
# Construct variables and update others
# Some trainable nodes might not receive feedback, because they might not be connected to the output
@@ -315,10 +317,14 @@ def summarize(self):
trainable_param_dict = {p.py_name: p for p in self.parameters if p.trainable}
summary.variables = {
- py_name: data for py_name, data in summary.roots.items() if py_name in trainable_param_dict
+ py_name: data
+ for py_name, data in summary.roots.items()
+ if py_name in trainable_param_dict
}
summary.inputs = {
- py_name: data for py_name, data in summary.roots.items() if py_name not in trainable_param_dict
+ py_name: data
+ for py_name, data in summary.roots.items()
+ if py_name not in trainable_param_dict
} # non-variable roots
return summary
@@ -348,29 +354,52 @@ def repr_node_constraint(node_dict):
def problem_instance(self, summary, mask=None):
mask = mask or []
return ProblemInstance(
- instruction=self.objective if '#Instruction' not in mask else "",
- code="\n".join([v for k, v in sorted(summary.graph)]) if "#Code" not in mask else "",
- documentation="\n".join([v for v in summary.documentation.values()])
- if "#Documentation" not in mask
- else "",
- variables=self.repr_node_value(summary.variables) if "#Variables" not in mask else "",
- constraints=self.repr_node_constraint(summary.variables) if "#Constraints" not in mask else "",
- inputs=self.repr_node_value(summary.inputs) if "#Inputs" not in mask else "",
- outputs=self.repr_node_value(summary.output) if "#Outputs" not in mask else "",
- others=self.repr_node_value(summary.others) if "#Others" not in mask else "",
+ instruction=self.objective if "#Instruction" not in mask else "",
+ code=(
+ "\n".join([v for k, v in sorted(summary.graph)])
+ if "#Code" not in mask
+ else ""
+ ),
+ documentation=(
+ "\n".join([v for v in summary.documentation.values()])
+ if "#Documentation" not in mask
+ else ""
+ ),
+ variables=(
+ self.repr_node_value(summary.variables)
+ if "#Variables" not in mask
+ else ""
+ ),
+ constraints=(
+ self.repr_node_constraint(summary.variables)
+ if "#Constraints" not in mask
+ else ""
+ ),
+ inputs=(
+ self.repr_node_value(summary.inputs) if "#Inputs" not in mask else ""
+ ),
+ outputs=(
+ self.repr_node_value(summary.output) if "#Outputs" not in mask else ""
+ ),
+ others=(
+ self.repr_node_value(summary.others) if "#Others" not in mask else ""
+ ),
feedback=summary.user_feedback if "#Feedback" not in mask else "",
)
def construct_prompt(self, summary, mask=None, *args, **kwargs):
"""Construct the system and user prompt."""
- system_prompt = self.representation_prompt + self.output_format_prompt # generic representation + output rule
+ system_prompt = (
+ self.representation_prompt + self.output_format_prompt
+ ) # generic representation + output rule
user_prompt = self.user_prompt_template.format(
problem_instance=str(self.problem_instance(summary, mask=mask))
) # problem instance
if self.include_example:
user_prompt = (
self.example_problem_template.format(
- example_problem=self.example_problem, example_response=self.example_response
+ example_problem=self.example_problem,
+ example_response=self.example_response,
)
+ user_prompt
)
@@ -405,7 +434,9 @@ def replace_symbols(self, text: str, symbols: Dict[str, str]) -> str:
text = text.replace(self.default_prompt_symbols[k], v)
return text
- def _step(self, verbose=False, mask=None, *args, **kwargs) -> Dict[ParameterNode, Any]:
+ def _step(
+ self, verbose=False, mask=None, *args, **kwargs
+ ) -> Dict[ParameterNode, Any]:
assert isinstance(self.propagator, GraphPropagator)
summary = self.summarize()
system_prompt, user_prompt = self.construct_prompt(summary, mask=mask)
@@ -414,7 +445,10 @@ def _step(self, verbose=False, mask=None, *args, **kwargs) -> Dict[ParameterNode
user_prompt = self.replace_symbols(user_prompt, self.prompt_symbols)
response = self.call_llm(
- system_prompt=system_prompt, user_prompt=user_prompt, verbose=verbose, max_tokens=self.max_tokens
+ system_prompt=system_prompt,
+ user_prompt=user_prompt,
+ verbose=verbose,
+ max_tokens=self.max_tokens,
)
if "TERMINATE" in response:
@@ -424,12 +458,22 @@ def _step(self, verbose=False, mask=None, *args, **kwargs) -> Dict[ParameterNode
update_dict = self.construct_update_dict(suggestion)
if self.log is not None:
- self.log.append({"system_prompt": system_prompt, "user_prompt": user_prompt, "response": response})
- self.summary_log.append({'problem_instance': self.problem_instance(summary), 'summary': summary})
+ self.log.append(
+ {
+ "system_prompt": system_prompt,
+ "user_prompt": user_prompt,
+ "response": response,
+ }
+ )
+ self.summary_log.append(
+ {"problem_instance": self.problem_instance(summary), "summary": summary}
+ )
return update_dict
- def construct_update_dict(self, suggestion: Dict[str, Any]) -> Dict[ParameterNode, Any]:
+ def construct_update_dict(
+ self, suggestion: Dict[str, Any]
+ ) -> Dict[ParameterNode, Any]:
"""Convert the suggestion in text into the right data type."""
# TODO: might need some automatic type conversion
update_dict = {}
@@ -491,19 +535,26 @@ def extract_llm_suggestion(self, response: str):
# if the suggested value is a code, and the entire code body is empty (i.e., not even function signature is present)
# then we remove such suggestion
for key, value in suggestion.items():
- if "__code" in key and value == '':
+ if "__code" in key and value == "":
del suggestion[key]
return suggestion
def call_llm(
- self, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False, max_tokens: int = 4096
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ verbose: Union[bool, str] = False,
+ max_tokens: int = 4096,
):
"""Call the LLM with a prompt and return the response."""
if verbose not in (False, "output"):
print("Prompt\n", system_prompt + user_prompt)
- messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
try: # Try tp force it to be a json object
response = self.llm(
@@ -517,4 +568,4 @@ def call_llm(
if verbose:
print("LLM response:\n", response)
- return response
\ No newline at end of file
+ return response
diff --git a/opto/optimizers/optoprimemulti.py b/opto/optimizers/optoprimemulti.py
index cd55314f..dc680187 100644
--- a/opto/optimizers/optoprimemulti.py
+++ b/opto/optimizers/optoprimemulti.py
@@ -7,28 +7,39 @@
class OptoPrimeMulti(OptoPrime):
- def __init__(self, *args,
- num_responses: int = 5,
- temperature_range: Optional[List[float]] = None,
- selector: Optional[callable] = None,
- **kwargs):
+ def __init__(
+ self,
+ *args,
+ num_responses: int = 5,
+ temperature_range: Optional[List[float]] = None,
+ selector: Optional[callable] = None,
+ **kwargs,
+ ):
super().__init__(*args, **kwargs)
if temperature_range is None:
- self.temperature_range = [1.3, 0.]
+ self.temperature_range = [1.3, 0.0]
self.candidates = [] # Store all candidate solutions
self.selected_candidate = None # Store the selected candidate solution
self.num_responses = num_responses
self.selector = selector
def call_llm(
- self, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False,
- max_tokens: int = 4096, num_responses: int = 1, temperature: float = 0.
+ self,
+ system_prompt: str,
+ user_prompt: str,
+ verbose: Union[bool, str] = False,
+ max_tokens: int = 4096,
+ num_responses: int = 1,
+ temperature: float = 0.0,
) -> List[str]:
"""Call the LLM with a prompt and return multiple responses."""
if verbose not in (False, "output"):
print("Prompt\n", system_prompt + user_prompt)
- messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
try:
response = self.llm.create(
@@ -51,8 +62,15 @@ def call_llm(
return responses
def generate_candidates(
- self, summary, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False,
- mask=None, max_tokens: int = None, num_responses: Optional[int] = None, temperature_range: Optional[List[float]] = None
+ self,
+ summary,
+ system_prompt: str,
+ user_prompt: str,
+ verbose: Union[bool, str] = False,
+ mask=None,
+ max_tokens: int = None,
+ num_responses: Optional[int] = None,
+ temperature_range: Optional[List[float]] = None,
) -> List[str]:
"""
Generate multiple candidates with progressively decreasing temperatures.
@@ -68,11 +86,19 @@ def generate_candidates(
Returns:
List[str]: List of LLM responses as strings.
"""
- num_responses = num_responses if num_responses is not None else self.num_responses # Allow overriding num_responses
- temperature_range = temperature_range if temperature_range is not None else self.temperature_range
+ num_responses = (
+ num_responses if num_responses is not None else self.num_responses
+ ) # Allow overriding num_responses
+ temperature_range = (
+ temperature_range
+ if temperature_range is not None
+ else self.temperature_range
+ )
max_tokens = max_tokens or self.max_tokens # Allow overriding max_tokens
- max_temp, min_temp = max(temperature_range), min(temperature_range) # Ensure max > min
+ max_temp, min_temp = max(temperature_range), min(
+ temperature_range
+ ) # Ensure max > min
temperatures = [
max_temp - i * (max_temp - min_temp) / max(1, num_responses - 1)
for i in range(num_responses)
@@ -88,14 +114,24 @@ def generate_candidates(
verbose=verbose,
max_tokens=max_tokens,
num_responses=1,
- temperature=temp
- )[0] # Extract the single response
+ temperature=temp,
+ )[
+ 0
+ ] # Extract the single response
for temp in temperatures
]
if self.log is not None:
- self.log.append({"system_prompt": system_prompt, "user_prompt": user_prompt, "response": candidates})
- self.summary_log.append({'problem_instance': self.problem_instance(summary), 'summary': summary})
+ self.log.append(
+ {
+ "system_prompt": system_prompt,
+ "user_prompt": user_prompt,
+ "response": candidates,
+ }
+ )
+ self.summary_log.append(
+ {"problem_instance": self.problem_instance(summary), "summary": summary}
+ )
return candidates
@@ -110,8 +146,14 @@ def select_candidate(self, candidates: List[Dict]) -> Dict: # Fixed type annota
return candidates[-1] if candidates else {} # Default to the last candidate
def _step(
- self, verbose=False, mask=None, num_responses: Optional[int] = None, temperature_range: Optional[List[float]] = None,
- selector: callable = None, *args, **kwargs
+ self,
+ verbose=False,
+ mask=None,
+ num_responses: Optional[int] = None,
+ temperature_range: Optional[List[float]] = None,
+ selector: callable = None,
+ *args,
+ **kwargs,
) -> Dict: # Added type annotation for return value
"""
Perform a single optimization step, storing responses in self.responses and allowing selection.
@@ -124,8 +166,14 @@ def _step(
Returns:
Dict: The update dictionary based on the selected response.
"""
- num_responses = num_responses if num_responses is not None else self.num_responses # Allow overriding num_responses
- temperature_range = temperature_range if temperature_range is not None else self.temperature_range
+ num_responses = (
+ num_responses if num_responses is not None else self.num_responses
+ ) # Allow overriding num_responses
+ temperature_range = (
+ temperature_range
+ if temperature_range is not None
+ else self.temperature_range
+ )
selector = selector if selector is not None else self.selector
assert isinstance(self.propagator, GraphPropagator)
@@ -137,8 +185,13 @@ def _step(
# Generate candidates
responses = self.generate_candidates(
- summary, system_prompt, user_prompt, verbose=verbose, mask=mask,
- num_responses=num_responses, temperature_range=temperature_range
+ summary,
+ system_prompt,
+ user_prompt,
+ verbose=verbose,
+ mask=mask,
+ num_responses=num_responses,
+ temperature_range=temperature_range,
)
self.candidates = [] # Clear previous responses
diff --git a/opto/optimizers/textgrad.py b/opto/optimizers/textgrad.py
index 72cab4a8..e63a026d 100644
--- a/opto/optimizers/textgrad.py
+++ b/opto/optimizers/textgrad.py
@@ -61,9 +61,7 @@
"Here is the context and feedback we got for the variable:\n\n"
)
-TGD_MULTIPART_PROMPT_PREFIX = (
- "Improve the variable ({variable_desc}) using the feedback provided in tags.\n"
-)
+TGD_MULTIPART_PROMPT_PREFIX = "Improve the variable ({variable_desc}) using the feedback provided in tags.\n"
TGD_PROMPT_SUFFIX = (
"Send the improved variable "
@@ -89,10 +87,12 @@
)
-def construct_tgd_prompt(do_momentum: bool = False,
- do_constrained: bool = False,
- do_in_context_examples: bool = False,
- **optimizer_kwargs):
+def construct_tgd_prompt(
+ do_momentum: bool = False,
+ do_constrained: bool = False,
+ do_in_context_examples: bool = False,
+ **optimizer_kwargs,
+):
"""
Construct the textual gradient descent prompt.
@@ -113,7 +113,9 @@ def construct_tgd_prompt(do_momentum: bool = False,
else:
gradient_context = optimizer_kwargs["variable_grad"]
- gradient_context = [TGD_MULTIPART_PROMPT_INIT.format(**optimizer_kwargs)] + gradient_context
+ gradient_context = [
+ TGD_MULTIPART_PROMPT_INIT.format(**optimizer_kwargs)
+ ] + gradient_context
multipart = True
prompt = TGD_MULTIPART_PROMPT_PREFIX.format(**optimizer_kwargs)
@@ -170,7 +172,8 @@ def construct_tgd_prompt(do_momentum: bool = False,
"Only provide strategies, explanations, and methods to change in the variable. DO NOT propose a new version of the variable, that will be the job of the optimizer. Your only job is to send feedback and criticism (compute 'gradients'). "
"For instance, feedback can be in the form of 'Since language models have the X failure mode...', 'Adding X can fix this error because...', 'Removing X can improve the objective function because...', 'Changing X to Y would fix the mistake ...', that gets at the downstream objective.\n"
"If a variable is already working well (e.g. the objective function is perfect, an evaluation shows the response is accurate), you should not give feedback.\n"
- f"{GLOSSARY_TEXT_BACKWARD}")
+ f"{GLOSSARY_TEXT_BACKWARD}"
+)
# First part of the prompt for the llm backward function
CONVERSATION_TEMPLATE = (
@@ -223,10 +226,6 @@ def construct_tgd_prompt(do_momentum: bool = False,
"{results_gradient}\n\n"
)
-IN_CONTEXT_EXAMPLE_PROMPT_ADDITION = (
- "You must base on the following examples when give feedback and criticism to the variable:\n\n"
- "{in_context_examples}\n\n"
-)
"""
Gradient accumulation: reduce / sum
@@ -283,7 +282,7 @@ def rm_node_attrs(text: str) -> str:
Returns:
String with trace node attributes removed
"""
- return re.sub(r'\[.*?\]', '', text).strip()
+ return re.sub(r"\[.*?\]", "", text).strip()
def get_short_value(text, n_words_offset: int = 10) -> str:
@@ -299,44 +298,57 @@ def get_short_value(text, n_words_offset: int = 10) -> str:
words = text.split(" ")
if len(words) <= 2 * n_words_offset:
return text
- short_value = " ".join(words[:n_words_offset]) + " (...) " + " ".join(words[-n_words_offset:])
+ short_value = (
+ " ".join(words[:n_words_offset]) + " (...) " + " ".join(words[-n_words_offset:])
+ )
return short_value
class TextGrad(Optimizer):
- def __init__(self, parameters: List[ParameterNode],
- llm: AutoGenLLM = None,
- *args,
- propagator: Propagator = None,
- objective: Union[None, str] = None,
- max_tokens=4096,
- log=False,
- **kwargs, ):
+ def __init__(
+ self,
+ parameters: List[ParameterNode],
+ llm: AutoGenLLM = None,
+ *args,
+ propagator: Propagator = None,
+ objective: Union[None, str] = None,
+ max_tokens=4096,
+ log=False,
+ **kwargs,
+ ):
super().__init__(parameters, *args, **kwargs)
self.llm = llm or AutoGenLLM()
self.print_limit = 100
self.max_tokens = max_tokens
self.new_variable_tags = ["", ""]
- self.optimizer_system_prompt = OPTIMIZER_SYSTEM_PROMPT.format(new_variable_start_tag=self.new_variable_tags[0],
- new_variable_end_tag=self.new_variable_tags[1])
+ self.optimizer_system_prompt = OPTIMIZER_SYSTEM_PROMPT.format(
+ new_variable_start_tag=self.new_variable_tags[0],
+ new_variable_end_tag=self.new_variable_tags[1],
+ )
self.log = [] if log else None
def _construct_backward_prompt(self, backward_info):
conversation = CONVERSATION_TEMPLATE.format(**backward_info)
- backward_prompt = CONVERSATION_START_INSTRUCTION_BASE.format(conversation=conversation, **backward_info)
+ backward_prompt = CONVERSATION_START_INSTRUCTION_BASE.format(
+ conversation=conversation, **backward_info
+ )
backward_prompt += OBJECTIVE_INSTRUCTION_BASE.format(**backward_info)
backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
return backward_prompt
def _construct_chain_backward_prompt(self, backward_info) -> str:
conversation = CONVERSATION_TEMPLATE.format(**backward_info)
- backward_prompt = CONVERSATION_START_INSTRUCTION_CHAIN.format(conversation=conversation, **backward_info)
+ backward_prompt = CONVERSATION_START_INSTRUCTION_CHAIN.format(
+ conversation=conversation, **backward_info
+ )
backward_prompt += OBJECTIVE_INSTRUCTION_CHAIN.format(**backward_info)
backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
return backward_prompt
- def _grad(self, input_node: Node, parent_nodes, gradient_text, verbose=False) -> List[GradientInfo]:
+ def _grad(
+ self, input_node: Node, parent_nodes, gradient_text, verbose=False
+ ) -> List[GradientInfo]:
"""
https://github.com/zou-group/textgrad/blob/main/textgrad/autograd/llm_ops.py#L174
@@ -353,16 +365,19 @@ def _grad(self, input_node: Node, parent_nodes, gradient_text, verbose=False) ->
"response_gradient": gradient_text,
"prompt": var_node.data, # prompt = input to the operation
"variable_desc": rm_node_attrs(var_node.description),
- "variable_short": get_short_value(var_node.data)
+ "variable_short": get_short_value(var_node.data),
}
backward_prompt = self._construct_chain_backward_prompt(backward_info)
- gradient_value = self.call_llm(user_prompt=backward_prompt, system_prompt=BACKWARD_SYSTEM_PROMPT,
- verbose=verbose)
+ gradient_value = self.call_llm(
+ user_prompt=backward_prompt,
+ system_prompt=BACKWARD_SYSTEM_PROMPT,
+ verbose=verbose,
+ )
conversation = CONVERSATION_TEMPLATE.format(**backward_info)
gradients_context = {
"context": conversation,
"response_desc": rm_node_attrs(input_node.description),
- "variable_desc": rm_node_attrs(var_node.description)
+ "variable_desc": rm_node_attrs(var_node.description),
}
propagated_grads.append(GradientInfo(gradient_value, gradients_context))
@@ -373,9 +388,11 @@ def _reduce_gradient_mean(self, gradients: List[GradientInfo], verbose=False):
return gradients[0]
else:
gradient_reduce_prompt = construct_reduce_prompt(gradients)
- reduced_gradient = self.call_llm(user_prompt=gradient_reduce_prompt,
- system_prompt=REDUCE_MEAN_SYSTEM_PROMPT,
- verbose=verbose)
+ reduced_gradient = self.call_llm(
+ user_prompt=gradient_reduce_prompt,
+ system_prompt=REDUCE_MEAN_SYSTEM_PROMPT,
+ verbose=verbose,
+ )
return reduced_gradient
def _get_gradient_and_context_text(self, gradients: List[GradientInfo]):
@@ -385,7 +402,8 @@ def _get_gradient_and_context_text(self, gradients: List[GradientInfo]):
gradient_content.append(g.gradient)
else:
criticism_and_context = GRADIENT_TEMPLATE.format(
- feedback=g.gradient, **g.gradient_context)
+ feedback=g.gradient, **g.gradient_context
+ )
gradient_content.append(criticism_and_context)
return "\n".join(gradient_content)
@@ -403,24 +421,28 @@ def _update_prompt(self, node: Node, gradients: List[GradientInfo]):
# "gradient_memory": gradient_memory
}
- prompt = construct_tgd_prompt(do_constrained=True,
- do_in_context_examples=False,
- do_gradient_memory=False,
- **optimizer_information)
+ prompt = construct_tgd_prompt(
+ do_constrained=True,
+ do_in_context_examples=False,
+ do_gradient_memory=False,
+ **optimizer_information,
+ )
return prompt
def _step(self, verbose=False):
# aggregate the trace graphes into one.
trace_graph = copy(self.trace_graph)
- # make sure it's sorted
- graph = sorted(trace_graph.graph, key=lambda x: x[0]) # sort by level
# this is the same as gradient memory
- grads = defaultdict(list) # accumulated gradient (same as variable.get_gradient_text())
+ grads = defaultdict(
+ list
+ ) # accumulated gradient (same as variable.get_gradient_text())
# trace_graph.graph is a list of nodes sorted according to the topological order
- for i, (_, x) in enumerate(reversed(graph)): # back-propagation starts from the last node
+ for i, (_, x) in enumerate(
+ reversed(trace_graph.graph)
+ ): # back-propagation starts from the last node
if len(x.parents) == 0:
continue
# we take the gradient step-by-step
@@ -446,10 +468,17 @@ def _step(self, verbose=False):
for p in self.parameters:
gradients = grads[p]
prompt_update_parameter = self._update_prompt(p, gradients)
- response = self.call_llm(user_prompt=prompt_update_parameter, system_prompt=self.optimizer_system_prompt,
- verbose=verbose)
+ response = self.call_llm(
+ user_prompt=prompt_update_parameter,
+ system_prompt=self.optimizer_system_prompt,
+ verbose=verbose,
+ )
try:
- var_json = response.split(self.new_variable_tags[0])[1].split(self.new_variable_tags[1])[0].strip()
+ var_json = (
+ response.split(self.new_variable_tags[0])[1]
+ .split(self.new_variable_tags[1])[0]
+ .strip()
+ )
# processing to fix JSON
# var_json = remove_non_ascii(escape_json_nested_quotes(var_json).replace("\n", "\\n"))
# new_proposal = json.loads(var_json)
@@ -462,18 +491,23 @@ def _step(self, verbose=False):
print(f"Error in updating {p.py_name}: {e}, raw response: {response}")
if self.log is not None:
- self.log.append({"user_prompt": prompt_update_parameter, "response": response})
+ self.log.append(
+ {"user_prompt": prompt_update_parameter, "response": response}
+ )
return update_dict # propose new update
def call_llm(
- self, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False
+ self, system_prompt: str, user_prompt: str, verbose: Union[bool, str] = False
):
"""Call the LLM with a prompt and return the response."""
if verbose not in (False, "output"):
- print("Prompt\n", system_prompt + '\n\n' + user_prompt)
+ print("Prompt\n", system_prompt + "\n\n" + user_prompt)
- messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
try:
response = self.llm.create(
@@ -486,4 +520,4 @@ def call_llm(
if verbose:
print("LLM response:\n", response)
- return response
\ No newline at end of file
+ return response
diff --git a/opto/optimizers/utils.py b/opto/optimizers/utils.py
index c81295f4..2f076f54 100644
--- a/opto/optimizers/utils.py
+++ b/opto/optimizers/utils.py
@@ -1,13 +1,15 @@
def print_color(message, color=None, logger=None):
colors = {
- 'red': '\033[91m',
- 'green': '\033[92m',
- 'yellow': '\033[93m',
- 'blue': '\033[94m',
- 'magenta': '\033[95m',
- 'cyan': '\033[96m'
+ "red": "\033[91m",
+ "green": "\033[92m",
+ "yellow": "\033[93m",
+ "blue": "\033[94m",
+ "magenta": "\033[95m",
+ "cyan": "\033[96m",
}
- print(f"{colors.get(color, '')}{message}\033[0m") # Default to no color if invalid color is provided
+ print(
+ f"{colors.get(color, '')}{message}\033[0m"
+ ) # Default to no color if invalid color is provided
if logger is not None:
- logger.log(message)
\ No newline at end of file
+ logger.log(message)
diff --git a/opto/trace/__init__.py b/opto/trace/__init__.py
index b0bb9067..0e6d38fd 100644
--- a/opto/trace/__init__.py
+++ b/opto/trace/__init__.py
@@ -8,6 +8,7 @@
from opto.trace.nodes import Node, GRAPH
from opto.trace.nodes import node
+
class stop_tracing:
"""A contextmanager to disable tracing."""
@@ -19,9 +20,15 @@ def __exit__(self, type, value, traceback):
__all__ = [
- 'node', 'stop_tracing', 'GRAPH', 'Node',
- 'bundle', 'ExecutionError',
- 'Module', 'NodeContainer', 'model',
- 'apply_op',
- 'propagators'
-]
\ No newline at end of file
+ "node",
+ "stop_tracing",
+ "GRAPH",
+ "Node",
+ "bundle",
+ "ExecutionError",
+ "Module",
+ "NodeContainer",
+ "model",
+ "apply_op",
+ "propagators",
+]
diff --git a/opto/trace/broadcast.py b/opto/trace/broadcast.py
index c88ff6d4..f157aa1f 100644
--- a/opto/trace/broadcast.py
+++ b/opto/trace/broadcast.py
@@ -5,7 +5,7 @@
def recursive_conversion(true_func, false_func):
- """ Recursively apply true_func to the nodes and false_func to the rest of
+ """Recursively apply true_func to the nodes and false_func to the rest of
the objects in a container of nodes. Container of nodes are tuple, list,
dict, set, and NodeContainer.
@@ -14,6 +14,7 @@ def recursive_conversion(true_func, false_func):
false_func (callable): the function to be applied to the rest of the objects.
"""
+
def func(obj):
if isinstance(obj, Node): # base case
return true_func(obj)
@@ -32,6 +33,7 @@ def func(obj):
return output
else:
return false_func(obj)
+
return func
@@ -55,9 +57,11 @@ def apply_op(op, output, *args, **kwargs):
# output = copy.deepcopy(containers[0]) # this would be used as the template of the output
def admissible_type(x, base):
- return type(x) == type(base) or isinstance(x, Node)
+ return type(x) is type(base) or isinstance(x, Node)
- assert all(admissible_type(x, output) for x in inputs) # All inputs are either Nodes or the same type as output
+ assert all(
+ admissible_type(x, output) for x in inputs
+ ) # All inputs are either Nodes or the same type as output
if isinstance(output, list) or isinstance(output, tuple):
assert all(
@@ -65,7 +69,9 @@ def admissible_type(x, base):
), f"output {output} and inputs {inputs} are of different lengths."
for k in range(len(output)):
_args = [x if isinstance(x, Node) else x[k] for x in args]
- _kwargs = {kk: vv if isinstance(vv, Node) else vv[k] for kk, vv in kwargs.items()}
+ _kwargs = {
+ kk: vv if isinstance(vv, Node) else vv[k] for kk, vv in kwargs.items()
+ }
output[k] = apply_op(op, output[k], *_args, **_kwargs)
if isinstance(output, tuple):
output = tuple(output)
@@ -73,13 +79,18 @@ def admissible_type(x, base):
elif isinstance(output, dict):
for k, v in output.items():
_args = [x if isinstance(x, Node) else x[k] for x in args]
- _kwargs = {kk: vv if isinstance(vv, Node) else vv[k] for kk, vv in kwargs.items()}
+ _kwargs = {
+ kk: vv if isinstance(vv, Node) else vv[k] for kk, vv in kwargs.items()
+ }
output[k] = apply_op(op, output[k], *_args, **_kwargs)
elif isinstance(output, NodeContainer): # this is a NodeContainer object instance
for k, v in output.__dict__.items():
_args = [x if isinstance(x, Node) else getattr(x, k) for x in args]
- _kwargs = {kk: vv if isinstance(v, Node) else getattr(vv, k) for kk, vv in kwargs.items()}
+ _kwargs = {
+ kk: vv if isinstance(v, Node) else getattr(vv, k)
+ for kk, vv in kwargs.items()
+ }
new_v = apply_op(op, v, *_args, **_kwargs)
setattr(output, k, new_v)
else:
diff --git a/opto/trace/bundle.py b/opto/trace/bundle.py
index 8c1115fa..048aed45 100644
--- a/opto/trace/bundle.py
+++ b/opto/trace/bundle.py
@@ -13,7 +13,15 @@
from opto.trace.errors import ExecutionError, TraceMissingInputsError
from opto.trace.modules import Module
from opto.trace.nodes import GRAPH
-from opto.trace.nodes import MessageNode, USED_NODES, Node, ParameterNode, ExceptionNode, node, get_op_name
+from opto.trace.nodes import (
+ MessageNode,
+ USED_NODES,
+ Node,
+ ParameterNode,
+ ExceptionNode,
+ node,
+ get_op_name,
+)
from opto.trace.utils import contain
@@ -28,7 +36,7 @@ def bundle(
):
"""Wrap a function as a FunModule which returns node objects.
- The input signature to the wrapped function stays the same. bundle can be used with other decorators
+ The input signature to the wrapped function stays the same. bundle can be used with other decorators
so long as they are not named 'bundle'.
Args:
@@ -44,8 +52,9 @@ def bundle(
FunModule: The wrapped function that returns node objects.
"""
prev_f_locals = inspect.stack()[1].frame.f_locals
+
def decorator(fun):
- fun_module= FunModule(
+ fun_module = FunModule(
fun=fun,
description=description,
traceable_code=traceable_code,
@@ -57,6 +66,7 @@ def decorator(fun):
_ldict=prev_f_locals, # Get the locals of the calling function
)
return fun_module
+
return decorator
@@ -101,14 +111,15 @@ def __init__(
_ldict=None,
):
- assert _ldict is None or isinstance(_ldict, dict), "_ldict must be a dictionary. or None"
+ assert _ldict is None or isinstance(
+ _ldict, dict
+ ), "_ldict must be a dictionary. or None"
self._ldict = {} if _ldict is None else _ldict.copy()
-
assert callable(fun), "fun must be a callable."
# Get the source code of the function, excluding the decorator line
- source, line_number = self.get_source(fun)
+ source, line_number = self.get_source(fun)
# Construct the info dictionary
docstring = inspect.getdoc(fun)
@@ -149,12 +160,16 @@ def __init__(
# assert overwrite_python_recursion, "trainable requires overwrite_python_recursion to be True."
signature_sr = re.search(r"\s*(def.*\"\"\")", source, re.DOTALL)
- if signature_sr is None: # if there is no docstring just take the first line
+ if (
+ signature_sr is None
+ ): # if there is no docstring just take the first line
signature = re.search(r"\s*(def.*:)", source).group(1)
else:
signature = signature_sr.group(1)
self.parameter = ParameterNode(
- self.info["source"], name="__code", constraint="The code should start with:\n" + signature
+ self.info["source"],
+ name="__code",
+ constraint="The code should start with:\n" + signature,
)
@property
@@ -163,13 +178,15 @@ def trainable(self):
@property
def fun(self, *args, **kwargs):
- """ Return a callable function. Return the decorated function if the parameter is None. Otherwise, return the function defined by the parameter. When exception happens during the defining the function with the parameter, raise a trace.ExecutionError. """
+ """Return a callable function. Return the decorated function if the parameter is None. Otherwise, return the function defined by the parameter. When exception happens during the defining the function with the parameter, raise a trace.ExecutionError."""
# This function should be later called within trace_nodes context manager.
if self.parameter is None:
return self._fun
else:
- code = self.parameter._data # This is not traced, but we will add this as the parent later.
+ code = (
+ self.parameter._data
+ ) # This is not traced, but we will add this as the parent later.
# before we execute, we should try to import all the global name spaces from the original function
try:
_ldict = {}
@@ -194,11 +211,18 @@ def fun(self, *args, **kwargs):
else:
return fun
- base_message = f'({error_class}) {detail}.'
- commented_code = self.generate_comment(code, base_message, line_number, 1) + f"\n{base_message}"
- raw_traceback = 'SyntaxError in trainable code definition.\n' + commented_code if 'SyntaxError' == error_class else traceback.format_exc()
- self.info['error_comment'] = commented_code
- self.info['traceback'] = raw_traceback # This is saved for user debugging
+ base_message = f"({error_class}) {detail}."
+ commented_code = (
+ self.generate_comment(code, base_message, line_number, 1)
+ + f"\n{base_message}"
+ )
+ raw_traceback = (
+ "SyntaxError in trainable code definition.\n" + commented_code
+ if "SyntaxError" == error_class
+ else traceback.format_exc()
+ )
+ self.info["error_comment"] = commented_code
+ self.info["traceback"] = raw_traceback # This is saved for user debugging
e_node = ExceptionNode(
e,
@@ -215,7 +239,7 @@ def name(self):
return get_op_name(self.description)
def _wrap_inputs(self, fun, args, kwargs):
- """ Wrap the inputs to a function as nodes when they're not.
+ """Wrap the inputs to a function as nodes when they're not.
Args:
fun (callable): the function to be wrapped.
@@ -229,7 +253,7 @@ def _wrap_inputs(self, fun, args, kwargs):
_args (list): the original positional arguments (including the default values).
_kwargs (dict): the original keyword arguments (including the default values).
"""
- ## Wrap the inputs as nodes
+ # Wrap the inputs as nodes
# add default into kwargs
ba = inspect.signature(fun).bind(*args, **kwargs)
@@ -238,57 +262,86 @@ def _wrap_inputs(self, fun, args, kwargs):
a1 = ba.arguments
fullargspec = inspect.getfullargspec(fun)
# include default into the kwargs
- for k,v in a1.items():
+ for k, v in a1.items():
if k not in a0:
if k != fullargspec.varargs and k != fullargspec.varkw:
kwargs[k] = v
# convert args and kwargs to nodes, except for FunModule
_args, _kwargs = args, kwargs # back up
- args = [node(a, name=fullargspec.args[i] if i < len(fullargspec.args) and not isinstance(a, Node) else None) if not isinstance(a, FunModule) else a for i, a in enumerate(args)]
- kwargs = {k: node(v, name=k if not isinstance(v, Node) else None) if not isinstance(v, FunModule) else v for k, v in kwargs.items()}
+ args = [
+ (
+ node(
+ a,
+ name=(
+ fullargspec.args[i]
+ if i < len(fullargspec.args) and not isinstance(a, Node)
+ else None
+ ),
+ )
+ if not isinstance(a, FunModule)
+ else a
+ )
+ for i, a in enumerate(args)
+ ]
+ kwargs = {
+ k: (
+ node(v, name=k if not isinstance(v, Node) else None)
+ if not isinstance(v, FunModule)
+ else v
+ )
+ for k, v in kwargs.items()
+ }
- ## Construct the input dict of the MessageNode from function inputs
+ # Construct the input dict of the MessageNode from function inputs
inputs = {}
# args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, ann
_, varargs, varkw, _, _, _, _ = inspect.getfullargspec(fun)
-
# bind the node version of args and kwargs
ba = inspect.signature(fun).bind(*args, **kwargs)
spec = ba.arguments
def extract_param(n):
- return n.parameter if isinstance(n, FunModule) and n.parameter is not None else n
+ return (
+ n.parameter
+ if isinstance(n, FunModule) and n.parameter is not None
+ else n
+ )
# expand varargs and varkw
for k, v in spec.items():
if k == varargs: # unpack varargs
for i, n in enumerate(v):
- inputs[f"args_{i}"] = extract_param(n) # TODO different representation?
+ inputs[f"args_{i}"] = extract_param(
+ n
+ ) # TODO different representation?
elif k == varkw: # unpack varkw
for kk, n in v.items():
inputs[kk] = extract_param(n)
else:
inputs[k] = extract_param(v)
- assert all([isinstance(n, Node) for n in inputs.values()]), "All values in inputs must be nodes."
+ assert all(
+ [isinstance(n, Node) for n in inputs.values()]
+ ), "All values in inputs must be nodes."
return inputs, args, kwargs, _args, _kwargs
def _get_tracer(self):
- """ Get a tracer to overwrite the python recursion behavior of calling the decorated function. """
+ """Get a tracer to overwrite the python recursion behavior of calling the decorated function."""
# Define a tracer to deal with recursive function calls
_bundled_func = None
- def tracer(frame, event, arg = None):
- """ This tracer modifies the local/global dict of the frame, so that
+
+ def tracer(frame, event, arg=None):
+ """This tracer modifies the local/global dict of the frame, so that
when a recursive call of the wrapped function is made, it calls the
unwrapped function."""
nonlocal _bundled_func
if frame.f_code == self._fun.__code__: # entering the wrapped function
# Use the original function, rather than the bundled function
- if event == 'call': # Detect potential recursive calls
+ if event == "call": # Detect potential recursive calls
if frame.f_code.co_name in frame.f_locals:
# # the function is not defined globally at the top level
current_fun = frame.f_locals[frame.f_code.co_name]
@@ -301,15 +354,18 @@ def tracer(frame, event, arg = None):
_bundled_func = current_fun # save the original function
frame.f_globals[frame.f_code.co_name] = self._fun
- elif event == 'return':
+ elif event == "return":
if frame.f_code.co_name in frame.f_globals:
frame.f_globals[frame.f_code.co_name] = _bundled_func
return tracer
+
return tracer
def _construct_error_comment(self, e):
- """ Construct the error comment on the source code and traceback. """
- self.info['traceback'] = traceback.format_exc() # This is saved for user debugging
+ """Construct the error comment on the source code and traceback."""
+ self.info["traceback"] = (
+ traceback.format_exc()
+ ) # This is saved for user debugging
# Construct message to optimizer
error_class = e.__class__.__name__
detail = e.args[0]
@@ -318,33 +374,49 @@ def _construct_error_comment(self, e):
n_fun_calls = len(traceback.extract_tb(tb))
# Step through the traceback stack
comments = []
- base_message = f'({error_class}) {detail}.'
+ base_message = f"({error_class}) {detail}."
for i, (f, ln) in enumerate(traceback.walk_tb(tb)):
- if i>0: # ignore the first one, since that is the try statement above
- error_message = base_message if i == n_fun_calls-1 else 'Error raised in function call. See below.'
-
- if i==1 and self.parameter is not None: # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here.
- comment = self.generate_comment(self.parameter._data, error_message, ln, 1)
- comment_backup = self.generate_comment(self.parameter._data, base_message, ln, 1)
+ if i > 0: # ignore the first one, since that is the try statement above
+ error_message = (
+ base_message
+ if i == n_fun_calls - 1
+ else "Error raised in function call. See below."
+ )
+
+ if (
+ i == 1 and self.parameter is not None
+ ): # this is the trainable function defined by exec, which needs special treatment. inspect.getsource doesn't work here.
+ comment = self.generate_comment(
+ self.parameter._data, error_message, ln, 1
+ )
+ comment_backup = self.generate_comment(
+ self.parameter._data, base_message, ln, 1
+ )
else:
try:
f_source, f_source_ln = self.get_source(f, bug_mode=True)
- except OSError: # OSError: could not get source code
+ except OSError: # OSError: could not get source code
# we reach the compiled C level, so the previous level is actually the bottom
comments[-1] = comment_backup # replace the previous comment
break # exit the loop
- comment = self.generate_comment(f_source, error_message, ln, f_source_ln)
- comment_backup = self.generate_comment(f_source, base_message, ln, f_source_ln)
+ comment = self.generate_comment(
+ f_source, error_message, ln, f_source_ln
+ )
+ comment_backup = self.generate_comment(
+ f_source, base_message, ln, f_source_ln
+ )
comments.append(comment)
- commented_code = '\n\n'.join(comments)
- self.info['error_comment'] = commented_code + f"\n{base_message}"
+ commented_code = "\n\n".join(comments)
+ self.info["error_comment"] = commented_code + f"\n{base_message}"
output = e
return output
def sync_call_fun(self, fun, *_args, **_kwargs):
- """ Call the operator fun and return the output. Catch the exception if catch_execution_error is True. """
+ """Call the operator fun and return the output. Catch the exception if catch_execution_error is True."""
oldtracer = sys.gettrace()
- if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior
+ if (
+ self.overwrite_python_recursion and self.parameter is None
+ ): # Overwrite the python recursion behavior
# Running a tracer would slow down the execution, so we only do this when necessary.
sys.settrace(self._get_tracer())
@@ -361,7 +433,9 @@ def sync_call_fun(self, fun, *_args, **_kwargs):
async def async_call_fun(self, fun, *_args, **_kwargs):
oldtracer = sys.gettrace()
- if self.overwrite_python_recursion and self.parameter is None: # Overwrite the python recursion behavior
+ if (
+ self.overwrite_python_recursion and self.parameter is None
+ ): # Overwrite the python recursion behavior
# Running a tracer would slow down the execution, so we only do this when necessary.
sys.settrace(self._get_tracer())
@@ -378,7 +452,7 @@ async def async_call_fun(self, fun, *_args, **_kwargs):
def preprocess_inputs(self, args, kwargs, _args, _kwargs):
# NOTE This function must be put inside the used_nodes context manager.
- """ Preprocess the inputs for the operator fun.
+ """Preprocess the inputs for the operator fun.
Args:
_args (list): the original positional arguments. This includes the default values.
@@ -391,7 +465,9 @@ def preprocess_inputs(self, args, kwargs, _args, _kwargs):
if self.traceable_code:
_args, _kwargs = detach_inputs(args), detach_inputs(kwargs)
else: # NOTE Extract data from the nodes and pass them to the function; This line must be put inside the used_nodes context manager.
- _args, _kwargs = to_data(args), to_data(kwargs) # read node.data; this ensures the inputs are treated as used nodes
+ _args, _kwargs = to_data(args), to_data(
+ kwargs
+ ) # read node.data; this ensures the inputs are treated as used nodes
# else the inputs are passed directly to the function
# so we don't change _args and _kwargs
return _args, _kwargs # this will be passed as the input to the function
@@ -411,11 +487,13 @@ def postprocess_output(self, output, fun, _args, _kwargs, used_nodes, inputs):
# Log inputs and output of the function call
self.info["output"] = output
- self.info['inputs']["args"] = _args
- self.info['inputs']["kwargs"] = _kwargs
+ self.info["inputs"]["args"] = _args
+ self.info["inputs"]["kwargs"] = _kwargs
# Nodes used to create the output but not in the inputs are external dependencies.
- external_dependencies = [node for node in used_nodes if not contain(inputs.values(), node)]
+ external_dependencies = [
+ node for node in used_nodes if not contain(inputs.values(), node)
+ ]
self.info["external_dependencies"] = external_dependencies
# Make sure all nodes in used_nodes are in the parents of the returned node.
@@ -425,29 +503,33 @@ def postprocess_output(self, output, fun, _args, _kwargs, used_nodes, inputs):
)
if not GRAPH.TRACE:
- inputs = {} # We don't need to keep track of the inputs if we are not tracing.
+ inputs = (
+ {}
+ ) # We don't need to keep track of the inputs if we are not tracing.
# Wrap the output as a MessageNode or an ExceptionNode
nodes = self.wrap(output, inputs, external_dependencies)
return nodes
def forward(self, *args, **kwargs):
- fun = self.fun # Define the function (only once)
- self.info['fun'] = fun
+ fun = self.fun # Define the function (only once)
+ self.info["fun"] = fun
if inspect.iscoroutinefunction(fun):
- return self.async_forward(fun, *args, **kwargs) # Return a coroutine that returns a MessageNode
+ return self.async_forward(
+ fun, *args, **kwargs
+ ) # Return a coroutine that returns a MessageNode
else:
return self.sync_forward(fun, *args, **kwargs) # Return a MessageNode
def sync_forward(self, fun, *args, **kwargs):
"""
- Call the operator fun and return a MessageNode. All nodes used in
- the operator fun are added to used_nodes during the execution. If
- the output is not a Node, we wrap it as a MessageNode, whose inputs
- are nodes in used_nodes. Sync version.
+ Call the operator fun and return a MessageNode. All nodes used in
+ the operator fun are added to used_nodes during the execution. If
+ the output is not a Node, we wrap it as a MessageNode, whose inputs
+ are nodes in used_nodes. Sync version.
"""
# Wrap the inputs as nodes
inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs)
- ## Execute fun
+ # Execute fun
with trace_nodes() as used_nodes:
# After exit, used_nodes contains the nodes whose data attribute is read in the operator fun.
_args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs)
@@ -458,23 +540,30 @@ def sync_forward(self, fun, *args, **kwargs):
async def async_forward(self, fun, *args, **kwargs):
"""
- Call the operator fun and return a MessageNode. All nodes used in
- the operator fun are added to used_nodes during the execution. If
- the output is not a Node, we wrap it as a MessageNode, whose inputs
- are nodes in used_nodes. Async version.
+ Call the operator fun and return a MessageNode. All nodes used in
+ the operator fun are added to used_nodes during the execution. If
+ the output is not a Node, we wrap it as a MessageNode, whose inputs
+ are nodes in used_nodes. Async version.
"""
# Wrap the inputs as nodes
inputs, args, kwargs, _args, _kwargs = self._wrap_inputs(fun, args, kwargs)
- ## Execute fun
+ # Execute fun
with trace_nodes() as used_nodes:
# After exit, used_nodes contains the nodes whose data attribute is read in the operator fun.
_args, _kwargs = self.preprocess_inputs(args, kwargs, _args, _kwargs)
- output = await self.async_call_fun(fun, *_args, **_kwargs) # use await to call the async function
+ output = await self.async_call_fun(
+ fun, *_args, **_kwargs
+ ) # use await to call the async function
# Wrap the output as a MessageNode or an ExceptionNode
nodes = self.postprocess_output(output, fun, _args, _kwargs, used_nodes, inputs)
return nodes
- def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external_dependencies: List[Node]):
+ def wrap(
+ self,
+ output: Any,
+ inputs: Union[List[Node], Dict[str, Node]],
+ external_dependencies: List[Node],
+ ):
"""Wrap the output as a MessageNode of inputs as the parents."""
# Some nodes are used in the operator fun, we need to wrap the output as a MessageNode.
if self.parameter is not None:
@@ -497,11 +586,15 @@ def wrap(self, output: Any, inputs: Union[List[Node], Dict[str, Node]], external
)
raise ExecutionError(e_node)
else:
- return MessageNode(output, description=description, inputs=inputs, name=name, info=info)
+ return MessageNode(
+ output, description=description, inputs=inputs, name=name, info=info
+ )
@staticmethod
def is_valid_output(output):
- return isinstance(output, Node) or (isinstance(output, tuple) and all([isinstance(o, Node) for o in output]))
+ return isinstance(output, Node) or (
+ isinstance(output, tuple) and all([isinstance(o, Node) for o in output])
+ )
def __get__(self, obj, objtype):
# Support instance methods.
@@ -510,18 +603,24 @@ def __get__(self, obj, objtype):
def detach(self):
return copy.deepcopy(self)
- def generate_comment(self, code: str, comment: str, comment_line_number: int, base_line_number: int = 0):
+ def generate_comment(
+ self,
+ code: str,
+ comment: str,
+ comment_line_number: int,
+ base_line_number: int = 0,
+ ):
commented_code = []
- for i, l in enumerate(code.split('\n')):
+ for i, l in enumerate(code.split("\n")):
if i == comment_line_number - base_line_number:
commented_code.append(f"{l} <--- {comment}")
else:
commented_code.append(f"{l}")
- commented_code = '\n'.join(commented_code)
+ commented_code = "\n".join(commented_code)
return commented_code
def get_source(self, obj: Any, bug_mode=False):
- """ Get the source code of the function and its line number, excluding the @bundle decorator line.
+ """Get the source code of the function and its line number, excluding the @bundle decorator line.
bug_mode=True means
We are in the forward() function, but there is an error during execution.
The error can be caused by a lambda function which does not have `def` in the source code.
@@ -539,54 +638,64 @@ def get_source(self, obj: Any, bug_mode=False):
or inline usage
>>> bundle()(fun) # or ....bundle()(fun)
"""
- source = inspect.getsource(obj) # the source includes @bundle, or @trace.bundle, etc. we will remove those parts.
+ source = inspect.getsource(
+ obj
+ ) # the source includes @bundle, or @trace.bundle, etc. we will remove those parts.
line_number = int(inspect.getsourcelines(obj)[1]) # line number of obj
# Check if it's a decorator or an inline usage.
decorator_usage = False
- lines = source.split('\n')
- for i, l in enumerate(lines):
- l = l.strip().split('#')[0] # remove spacing and comment
- if l == '':
+ lines = source.split("\n")
+ for i, line in enumerate(lines):
+ line = line.strip().split("#")[0] # remove spacing and comment
+ if line == "":
continue
- if l[0] == '@': # decorator line. check whether it's using bundle
+ if line[0] == "@": # decorator line. check whether it's using bundle
# use cases
# @bundle(
# @bundle\ i.e., change line
# @......bundle(
# @......bundle\
- if ('@bundle(' in l) or ('@bundle\\' in l) or \
- (re.search(r'@.*\.bundle\(.*', l) is not None) or \
- (re.search(r'@.*\.bundle\\.*', l) is not None):
+ if (
+ ("@bundle(" in line)
+ or ("@bundle\\" in line)
+ or (re.search(r"@.*\.bundle\(.*", line) is not None)
+ or (re.search(r"@.*\.bundle\\.*", line) is not None)
+ ):
decorator_usage = True
break # i is the where the bundle decorator is used
-
if decorator_usage:
line_offset = i # to account for @bundle is not the top decorator
# Extract the lines after @bundle(....)
- inner_source = '\n'.join(lines[i:]) # i is where @bundle is used
- assert 'def ' in inner_source
+ inner_source = "\n".join(lines[i:]) # i is where @bundle is used
+ assert "def " in inner_source
# str after the first bundle
- after_bundle = 'bundle'.join(inner_source.split('bundle')[1:]) # NOTE there may be multiple usages of bundle in the comments
+ after_bundle = "bundle".join(
+ inner_source.split("bundle")[1:]
+ ) # NOTE there may be multiple usages of bundle in the comments
# Find where the scope of brackets
count = 0
for i, t in enumerate(after_bundle):
- if t == '(':
+ if t == "(":
count += 1
- elif t == ')':
+ elif t == ")":
count -= 1
if count == 0:
break
# Get the decorated source code
- after_bundle_call = after_bundle[i+1:] # after bundle(....)
- extracted_source = '\n'.join(after_bundle_call.split('\n')[1:]) # remove the first \n
+ after_bundle_call = after_bundle[i + 1:] # after bundle(....)
+ extracted_source = "\n".join(
+ after_bundle_call.split("\n")[1:]
+ ) # remove the first \n
extracted_source = extracted_source.strip()
# Get the line number of the decorated source code
- within_bundle_call = after_bundle[:i+1]
- n_line_changes = line_offset + 1 + within_bundle_call.count('\n') # the latter is the lines within the bundle call
+ within_bundle_call = after_bundle[: i + 1]
+ n_line_changes = (
+ line_offset + 1 + within_bundle_call.count("\n")
+ ) # the latter is the lines within the bundle call
line_number += n_line_changes
else:
# The inline usecase of
@@ -595,7 +704,9 @@ def get_source(self, obj: Any, bug_mode=False):
extracted_source = inspect.getsource(obj).strip()
if not bug_mode:
- assert 'def' in extracted_source, f'def is not in the source code: {extracted_source}'
+ assert (
+ "def" in extracted_source
+ ), f"def is not in the source code: {extracted_source}"
return extracted_source, line_number
@@ -604,17 +715,19 @@ def to_data(obj):
"""Extract the data from a node or a container of nodes."""
return recursive_conversion(lambda x: x.data, lambda x: x)(obj)
+
def wrap_node(obj):
"""Wrap a node on top of the original object"""
return recursive_conversion(lambda x: x, lambda x: node(x))(obj)
+
def detach_inputs(obj):
"""Detach a node or a container of nodes."""
return recursive_conversion(lambda x: x.detach(), lambda x: x)(obj)
def update_local(frame, name, value):
- """ Update the value of a local variable in a frame."""
+ """Update the value of a local variable in a frame."""
frame.f_locals[name] = value
ctypes.pythonapi.PyFrame_LocalsToFast(ctypes.py_object(frame), ctypes.c_int(0))
@@ -630,4 +743,4 @@ def test(x):
print(y)
print("Parents", y.parents)
print("Children", y.children)
- print("Level", y._level)
\ No newline at end of file
+ print("Level", y._level)
diff --git a/opto/trace/containers.py b/opto/trace/containers.py
index 63e86789..85b2f0e9 100644
--- a/opto/trace/containers.py
+++ b/opto/trace/containers.py
@@ -3,22 +3,26 @@
from opto.trace.nodes import ParameterNode
import functools
+
class NodeContainer:
- """ An identifier for a container of nodes."""
+ """An identifier for a container of nodes."""
+
...
def trainable_method(method):
from opto.trace.bundle import FunModule
+
if isinstance(method, FunModule):
return method.trainable
return False
+
class ParameterContainer(NodeContainer):
- """ A container of parameter nodes. """
+ """A container of parameter nodes."""
def parameters(self):
- """ Return a flattned list of all the parameters in the model's
+ """Return a flattned list of all the parameters in the model's
parameters_dict, useful for optimization."""
parameters = []
for k, v in self.parameters_dict().items():
@@ -32,7 +36,7 @@ def parameters(self):
return parameters
def parameters_dict(self):
- """ Return a dictionary of all the parameters in the model, including
+ """Return a dictionary of all the parameters in the model, including
both trainable and non-trainable parameters. The dict contains
ParameterNodes or ParameterContainers.
"""
@@ -49,12 +53,14 @@ def parameters_dict(self):
elif isinstance(attr, ParameterContainer):
parameters[name] = attr
- assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
+ assert all(
+ isinstance(v, (ParameterNode, ParameterContainer))
+ for v in parameters.values()
+ )
return parameters # include both trainable and non-trainable parameters
-
class Seq(UserList, ParameterContainer):
"""
Seq is defined as having a length and an index.
@@ -62,14 +68,18 @@ class Seq(UserList, ParameterContainer):
"""
def __init__(self, *args):
- if len(args) == 1 and hasattr(args[0], "__len__") and hasattr(args[0], "__getitem__"):
+ if (
+ len(args) == 1
+ and hasattr(args[0], "__len__")
+ and hasattr(args[0], "__getitem__")
+ ):
seq = args[0]
else:
seq = args
super().__init__(initlist=seq)
def parameters_dict(self):
- """ Return a dictionary of all the parameters in the model, including
+ """Return a dictionary of all the parameters in the model, including
both trainable and non-trainable parameters. The dict contains
ParameterNodes or ParameterContainers.
"""
@@ -80,7 +90,10 @@ def parameters_dict(self):
elif isinstance(attr, ParameterContainer):
parameters[str(attr)] = attr # TODO: what is the name of the container?
- assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
+ assert all(
+ isinstance(v, (ParameterNode, ParameterContainer))
+ for v in parameters.values()
+ )
return parameters
@@ -94,7 +107,7 @@ def __init__(self, mapping):
super().__init__(mapping)
def parameters_dict(self):
- """ Return a dictionary of all the parameters in the model, including
+ """Return a dictionary of all the parameters in the model, including
both trainable and non-trainable parameters. The dict contains
ParameterNodes or ParameterContainers.
"""
@@ -110,6 +123,11 @@ def parameters_dict(self):
elif isinstance(k, ParameterContainer):
raise Exception("The key of a Map cannot be a container.")
- assert all(isinstance(v, (ParameterNode, ParameterContainer)) for v in parameters.values())
+ assert all(
+ isinstance(v, (ParameterNode, ParameterContainer))
+ for v in parameters.values()
+ )
return parameters
-#
\ No newline at end of file
+
+
+#
diff --git a/opto/trace/errors.py b/opto/trace/errors.py
index ee760f2e..70943124 100644
--- a/opto/trace/errors.py
+++ b/opto/trace/errors.py
@@ -1,6 +1,6 @@
-
from opto.trace.nodes import ExceptionNode
+
class ExecutionError(Exception):
"""Base class for execution error in code tracing."""
@@ -9,7 +9,7 @@ def __init__(self, exception_node: ExceptionNode):
super().__init__(self.exception_node.data)
def __str__(self):
- return '\n\n' + self.exception_node.info['traceback'] # show full traceback
+ return "\n\n" + self.exception_node.info["traceback"] # show full traceback
class TraceMissingInputsError(Exception):
diff --git a/opto/trace/iterators.py b/opto/trace/iterators.py
index bd42d8ea..8bc28779 100644
--- a/opto/trace/iterators.py
+++ b/opto/trace/iterators.py
@@ -1,14 +1,14 @@
from opto.trace.nodes import node, Node, ExceptionNode
from typing import Any
+
from opto.trace.bundle import bundle
import opto.trace.operators as ops
from opto.trace.errors import ExecutionError
-
# List[Nodes], Node[List]
def iterate(x: Any):
- """ Return an iterator object for node of list, tuple, set, or dict. """
+ """Return an iterator object for node of list, tuple, set, or dict."""
if not isinstance(x, Node):
x = node(x)
if issubclass(x.type, list) or issubclass(x.type, tuple) or issubclass(x.type, str):
@@ -19,11 +19,17 @@ def iterate(x: Any):
elif issubclass(x.type, dict):
return SeqIterable(x.keys())
else:
- raw_traceback = "TypeError: Cannot unpack non-iterable {} object".format(type(x._data))
+ raw_traceback = "TypeError: Cannot unpack non-iterable {} object".format(
+ type(x._data)
+ )
ex = TypeError(raw_traceback)
- e = ExceptionNode(ex, inputs=[x], info={
- 'traceback': raw_traceback,
- })
+ e = ExceptionNode(
+ ex,
+ inputs=[x],
+ info={
+ "traceback": raw_traceback,
+ },
+ )
raise ExecutionError(e)
@@ -72,27 +78,3 @@ def __next__(self):
return result
else:
raise StopIteration
-
-class DictIterable:
- def __init__(self, wrapped_dict):
- assert isinstance(wrapped_dict, Node)
- self._index = 0
- self.wrapped_dict = wrapped_dict
- self.keys = ops.keys(wrapped_dict)
-
- def __iter__(self):
- self._index = 0
- return self
-
- def __next__(self):
- if self._index < len(self.keys):
-
- key = self.keys[self._index]
- result = (key, self.wrapped_dict[key])
- self._index += 1
-
- assert self.wrapped_dict in result[1].parents
-
- return result
- else:
- raise StopIteration
\ No newline at end of file
diff --git a/opto/trace/modules.py b/opto/trace/modules.py
index ae7a8913..a85d1efb 100644
--- a/opto/trace/modules.py
+++ b/opto/trace/modules.py
@@ -17,7 +17,7 @@ class ModelWrapper(cls, Module):
class Module(ParameterContainer):
- """ Module is a ParameterContainer which has a forward method. """
+ """Module is a ParameterContainer which has a forward method."""
def forward(self, *args, **kwargs):
raise NotImplementedError
@@ -26,7 +26,7 @@ def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def save(self, file_name):
- """ Save the parameters of the model to a file."""
+ """Save the parameters of the model to a file."""
# detect if the directory exists
directory = os.path.dirname(file_name)
if directory != "":
@@ -35,13 +35,13 @@ def save(self, file_name):
pickle.dump(copy.deepcopy(self.parameters_dict()), f)
def load(self, file_name):
- """ Load the parameters of the model from a file."""
+ """Load the parameters of the model from a file."""
with open(file_name, "rb") as f:
loaded_data = pickle.load(f)
self._set(loaded_data)
def _set(self, new_parameters):
- """ Set the parameters of the model from a dictionary.
+ """Set the parameters of the model from a dictionary.
new_parameters is a ParamterContainer or a parameter dict.
"""
assert isinstance(new_parameters, (dict, ParameterContainer))
@@ -52,8 +52,9 @@ def _set(self, new_parameters):
parameters_dict = self.parameters_dict()
- assert all(k in new_parameters_dict for k in
- parameters_dict.keys()), """ Not all model parameters are in the new parameters dictionary. """
+ assert all(
+ k in new_parameters_dict for k in parameters_dict.keys()
+ ), """ Not all model parameters are in the new parameters dictionary. """
for k, v in new_parameters_dict.items():
if k in parameters_dict: # if the parameter exists
@@ -61,4 +62,4 @@ def _set(self, new_parameters):
parameters_dict[k]._set(v)
else: # if the parameter does not exist
assert k not in self.__dict__
- setattr(self, k, v)
\ No newline at end of file
+ setattr(self, k, v)
diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py
index 4e522936..6719aeb3 100644
--- a/opto/trace/nodes.py
+++ b/opto/trace/nodes.py
@@ -33,10 +33,16 @@ def node(data, name=None, trainable=False, description=None, constraint=None):
if trainable:
if isinstance(data, Node):
- name = name or data.name.split(':')[0]
+ name = name or data.name.split(":")[0]
data = data._data
- return ParameterNode(data, name=name, trainable=True, description=description, constraint=constraint)
+ return ParameterNode(
+ data,
+ name=name,
+ trainable=True,
+ description=description,
+ constraint=constraint,
+ )
else:
if isinstance(data, Node):
if name is not None:
@@ -53,7 +59,6 @@ class Graph:
"""Graph is a registry of all the nodes, forming a Directed Acyclic Graph (DAG).
Attributes:
- TRACE (bool): A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True.
_nodes (defaultdict): An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name.
Notes:
@@ -109,7 +114,7 @@ def register(self, node):
name = NAME_SCOPES[-1] + "/" + name
self._nodes[name].append(node)
node._name = (
- name + ":" + str(len(self._nodes[name]) - 1)
+ name + ":" + str(len(self._nodes[name]) - 1)
) # NOTE assume elements in self._nodes never get removed.
# self._levels[node._level].append(node)
@@ -159,38 +164,16 @@ def __len__(self):
GRAPH = Graph() # This is a global registry of all the nodes.
-USED_NODES = list() # A stack of sets. This is a global registry to track which nodes are read.
+USED_NODES = (
+ list()
+) # A stack of sets. This is a global registry to track which nodes are read.
T = TypeVar("T")
-"""Graph is a registry of all the nodes, forming a Directed Acyclic Graph (DAG).
-
- Attributes:
- TRACE (bool): A class-level boolean attribute that determines whether the graph is traced when creating MessageNode. Default is True.
- _nodes (defaultdict): An instance-level attribute, which is a defaultdict of lists, used as a lookup table to find nodes by name.
-
- Notes:
- The Graph class manages and organizes nodes in a Directed Acyclic Graph (DAG).
- It provides methods to register nodes, clear the graph, retrieve nodes by name, and identify root nodes.
- The `register` method assumes that elements in `_nodes` are never removed,
- which is important for maintaining the integrity of node names.
-"""
-
class AbstractNode(Generic[T]):
"""AbstractNode represents an abstract data node in a directed graph.
- Attributes:
- data: The data stored in the node.
- parents: The list of parent nodes.
- children: The list of child nodes.
- name: The name of the node.
- py_name: The name of the node without the ":" character.
- id: The ID of the node.
- level: The level of the node in the graph.
- is_root: A boolean indicating whether the node is a root node.
- is_leaf: A boolean indicating whether the node is a leaf node.
-
Notes:
The `AbstractNode` class is meant to be subclassed and extended to create specific types of nodes.
The node can have multiple parents and children, forming a directed graph structure.
@@ -237,7 +220,9 @@ def __init__(self, value, *, name=None, trainable=False) -> None:
self._parents = []
self._children = []
self._level = 0 # roots are at level 0
- default_name = str(type(value).__name__) + ":0" if name is None else name + ":0" # name:version
+ default_name = (
+ str(type(value).__name__) + ":0" if name is None else name + ":0"
+ ) # name:version
if isinstance(value, Node): # just a reference
self._data = value._data
self._name = value._name if name is None else default_name
@@ -396,10 +381,14 @@ def _add_parent(self, parent):
with child nodes always having a level greater than or equal to their parent nodes.
"""
assert parent is not self, "Cannot add self as a parent."
- assert isinstance(parent, Node), f"{parent} is {type(parent)}, which is not a Node."
+ assert isinstance(
+ parent, Node
+ ), f"{parent} is {type(parent)}, which is not a Node."
parent._children.append(self)
self._parents.append(parent)
- self._update_level(max(self._level, parent._level + 1)) # Update the level, because the parent is added
+ self._update_level(
+ max(self._level, parent._level + 1)
+ ) # Update the level, because the parent is added
def _update_level(self, new_level):
"""Update the level attribute of the current node.
@@ -462,7 +451,7 @@ def __deepcopy__(self, memo):
for k, v in self.__dict__.items():
if k == "_parents" or k == "_children":
setattr(result, k, [])
- elif k == '_feedback':
+ elif k == "_feedback":
setattr(result, k, defaultdict(list))
else:
setattr(result, k, copy.deepcopy(v, memo))
@@ -523,13 +512,17 @@ def get_op_name(description):
If a match is found, the operator type is extracted and returned.
Otherwise, a `ValueError` is raised with a specific error message.
"""
- assert type(description) is str, f"Description must be a string, but it is {type(description)}: {description}."
+ assert (
+ type(description) is str
+ ), f"Description must be a string, but it is {type(description)}: {description}."
match = re.search(r"^\[([^\[\]]+)\]", description)
if match:
operator_type = match.group(1)
return operator_type
else:
- raise ValueError(f"The description '{description}' must contain the operator type in square brackets.")
+ raise ValueError(
+ f"The description '{description}' must contain the operator type in square brackets."
+ )
class NodeVizStyleGuide:
@@ -540,7 +533,7 @@ class NodeVizStyleGuide:
print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100.
"""
- def __init__(self, style='default', print_limit=100):
+ def __init__(self, style="default", print_limit=100):
"""Initialize the NodeVizStyleGuide.
Args:
@@ -566,10 +559,10 @@ def get_attrs(self, x):
assign a color, and set the style.
"""
attrs = {
- 'label': self.get_label(x),
- 'shape': self.get_node_shape(x),
- 'fillcolor': self.get_color(x),
- 'style': self.get_style(x)
+ "label": self.get_label(x),
+ "shape": self.get_node_shape(x),
+ "fillcolor": self.get_color(x),
+ "style": self.get_style(x),
}
return attrs
@@ -591,7 +584,7 @@ def get_label(self, x):
# using colon in the name causes problems in graphviz
description = x.description
if len(x.description) > self.print_limit:
- description = x.description[:self.print_limit] + "..."
+ description = x.description[: self.print_limit] + "..."
text = x.py_name + "\n" + description + "\n"
content = str(x.data)
@@ -600,7 +593,7 @@ def get_label(self, x):
content = str(x.data["content"])
if len(content) > self.print_limit:
- content = content[:self.print_limit] + "..."
+ content = content[: self.print_limit] + "..."
return text + content
def get_node_shape(self, x):
@@ -617,8 +610,8 @@ def get_node_shape(self, x):
ParameterNode types are represented as 'box',
while other types are represented as 'ellipse'.
"""
- if type(x) == ParameterNode:
- return 'box'
+ if type(x) is ParameterNode:
+ return "box"
else:
return "ellipse"
@@ -636,10 +629,10 @@ def get_color(self, x):
ExceptionNode types are colored 'firebrick1',
and ParameterNode types are colored 'lightgray'.
"""
- if type(x) == ExceptionNode:
- return 'firebrick1'
- elif type(x) == ParameterNode:
- return '#DEEBF6'
+ if type(x) is ExceptionNode:
+ return "firebrick1"
+ elif type(x) is ParameterNode:
+ return "#DEEBF6"
return ""
@@ -656,7 +649,7 @@ def get_style(self, x):
The style of a node is set to 'filled,solid' if the node is trainable;
otherwise, it returns an empty string.
"""
- return 'filled,solid' if x.trainable else ""
+ return "filled,solid" if x.trainable else ""
class NodeVizStyleGuideColorful(NodeVizStyleGuide):
@@ -667,7 +660,7 @@ class NodeVizStyleGuideColorful(NodeVizStyleGuide):
print_limit (int): Sets the maximum number of characters to print for node descriptions and content. Default is 100.
"""
- def __init__(self, style='default', print_limit=100):
+ def __init__(self, style="default", print_limit=100):
"""Initialize the NodeVizStyleGuideColorful.
Args:
@@ -693,12 +686,12 @@ def get_attrs(self, x):
determine the node shape, assign a color, and set the style.
"""
attrs = {
- 'label': self.get_label(x),
- 'shape': self.get_node_shape(x),
- 'fillcolor': self.get_color(x),
- 'style': self.get_style(x),
- 'color': self.get_border_color(x),
- 'penwidth': "1.2"
+ "label": self.get_label(x),
+ "shape": self.get_node_shape(x),
+ "fillcolor": self.get_color(x),
+ "style": self.get_style(x),
+ "color": self.get_border_color(x),
+ "penwidth": "1.2",
}
return attrs
@@ -716,10 +709,10 @@ def get_border_color(self, x):
ExceptionNode types are colored 'firebrick1',
and ParameterNode types are colored 'black'.
"""
- if type(x) == ExceptionNode:
- return 'black'
- elif type(x) == ParameterNode:
- return '#FF7E79'
+ if type(x) is ExceptionNode:
+ return "black"
+ elif type(x) is ParameterNode:
+ return "#FF7E79"
return "#5C9BD5"
@@ -737,10 +730,10 @@ def get_color(self, x):
ExceptionNode types are colored 'firebrick1',
and ParameterNode types are colored 'lightgray'.
"""
- if type(x) == ExceptionNode:
- return 'firebrick1'
- elif type(x) == ParameterNode:
- return '#FFE5E5'
+ if type(x) is ExceptionNode:
+ return "firebrick1"
+ elif type(x) is ParameterNode:
+ return "#FFE5E5"
return "#DEEBF6"
@@ -753,7 +746,7 @@ def get_style(self, x):
Returns:
str: The style string 'filled,solid'.
"""
- return 'filled,solid'
+ return "filled,solid"
class Node(AbstractNode[T]):
@@ -788,14 +781,14 @@ class Node(AbstractNode[T]):
"""
def __init__(
- self,
- value: Any,
- *,
- name: str = None,
- trainable: bool = False,
- description: str = "[Node] This is a node in a computational graph.",
- constraint: Union[None, str] = None,
- info: Union[None, Dict] = None,
+ self,
+ value: Any,
+ *,
+ name: str = None,
+ trainable: bool = False,
+ description: str = "[Node] This is a node in a computational graph.",
+ constraint: Union[None, str] = None,
+ info: Union[None, Dict] = None,
) -> None:
"""Initialize an instance of the Node class.
@@ -813,22 +806,20 @@ def __init__(
matched = re.match(r"^\[([^\[\]]+)\]", description)
if not matched:
- description = '[Node] ' + description.strip()
+ description = "[Node] " + description.strip()
super().__init__(value, name=name)
self.trainable = trainable
- self._feedback = defaultdict(
- list
- ) # (analogous to gradient) this is the feedback from the user. Each key is a child and the value is a list of feedbacks from the child.
+ # (analogous to gradient) this is the feedback from the user. Each key is a child and the value is a list of feedbacks from the child.
# We keep the propagated feedback as dict and let the propagator performs
# the aggreation, rather than doing the aggregation incrementally. This is
# to support implementing aggregation that is not commutable.
- self._description = description # Information to describe of the node
- self._constraint = constraint # A constraint on the node
- self._backwarded = False # True if backward has been called
- self._info = info # Additional information about the node
- self._dependencies = {'parameter': set(),
- 'expandable': set()} # A dictionary of dependencies on parameters and expandable nodes; expandable nodes are those who depened on parameters not visible in the current graph level.
+ self._feedback = defaultdict(list)
+ self._description = description
+ self._constraint = constraint
+ self._backwarded = False
+ self._info = info
+ self._dependencies = {"parameter": set(), "expandable": set()}
def zero_feedback(self): # set feedback to zero
"""Zero out the feedback of the node.
@@ -868,7 +859,7 @@ def parameter_dependencies(self):
with a corresponding value before calling the parameter_dependencies function to avoid potential
KeyError exceptions.
"""
- return self._dependencies['parameter']
+ return self._dependencies["parameter"]
@property
def expandable_dependencies(self):
@@ -880,7 +871,7 @@ def expandable_dependencies(self):
with a corresponding value before calling the expandable_dependencies function to avoid potential
KeyError exceptions.
"""
- return self._dependencies['expandable']
+ return self._dependencies["expandable"]
def _add_feedback(self, child, feedback):
"""Add feedback from a child.
@@ -918,14 +909,14 @@ def _itemize(self): # for priority queue
return (-self.level, id(self), self)
def backward(
- self,
- feedback: Any = "",
- propagator=None,
- retain_graph=False,
- visualize=False,
- simple_visualization=True,
- reverse_plot=False,
- print_limit=100,
+ self,
+ feedback: Any = "",
+ propagator=None,
+ retain_graph=False,
+ visualize=False,
+ simple_visualization=True,
+ reverse_plot=False,
+ print_limit=100,
):
"""Performs a backward pass in a computational graph.
@@ -935,11 +926,11 @@ def backward(
Args:
feedback: The feedback given to the current node.
propagator: A function that takes in a node and a feedback, and returns a dict of {parent: parent_feedback}.
- If not provided, a default `GraphPropagator` object is used.
+ If not provided, a default `GraphPropagator` object is used.
retain_graph: If True, the graph will be retained after backward pass.
visualize: If True, the graph will be visualized using graphviz.
simple_visualization: If True, identity operators will be skipped in the visualization.
- reverse_plot: if True, plot the graph in reverse order (from child to parent).
+ reverse_plot: If True, plot the graph in reverse order (from child to parent).
print_limit: The maximum number of characters to print for node descriptions and content.
Returns:
@@ -956,7 +947,9 @@ def backward(
Visualization is handled using graphviz if enabled, with options to simplify the graph by skipping identity operators.
"""
if propagator is None:
- from opto.trace.propagators.graph_propagator import GraphPropagator # this avoids circular import
+ from opto.trace.propagators.graph_propagator import (
+ GraphPropagator,
+ ) # this avoids circular import
propagator = GraphPropagator()
@@ -973,7 +966,9 @@ def backward(
# Check for root node with no parents
if self._backwarded:
raise AttributeError(f"{self} has been backwarded.")
- self._add_feedback(Node("FEEDBACK_ORACLE"), propagator.init_feedback(self, feedback))
+ self._add_feedback(
+ Node("FEEDBACK_ORACLE"), propagator.init_feedback(self, feedback)
+ )
if len(self.parents) == 0: # This is a root. Nothing to propagate
if visualize:
@@ -982,10 +977,14 @@ def backward(
return digraph
# TODO check memory leak
- queue = [self._itemize()] # priority queue; add id() since __eq__ is overloaded to compare values.
+ queue = [
+ self._itemize()
+ ] # priority queue; add id() since __eq__ is overloaded to compare values.
while True:
try:
- _, _, node = heapq.heappop(queue) # All the children of this node have been visited
+ _, _, node = heapq.heappop(
+ queue
+ ) # All the children of this node have been visited
# Each node is a MessageNode, which has at least one parent.
assert len(node.parents) > 0 and isinstance(node, MessageNode)
if node._backwarded:
@@ -1003,18 +1002,32 @@ def backward(
parent._add_feedback(node, propagated_feedback[parent])
# Put parent in the queue if it has not been visited and it's not a root
- if len(parent.parents) > 0 and parent._itemize() not in queue: # and parent not in queue:
- heapq.heappush(queue, parent._itemize()) # put parent in the priority queue
+ if (
+ len(parent.parents) > 0 and parent._itemize() not in queue
+ ): # and parent not in queue:
+ heapq.heappush(
+ queue, parent._itemize()
+ ) # put parent in the priority queue
if visualize:
# Plot the edge from parent to node
# Bypass chain of identity operators (for better visualization)
- while (get_op_name(parent.description) in IDENTITY_OPERATORS) and simple_visualization:
- assert len(parent.parents) == 1 # identity operators should have only one parent
- visited.add(parent.py_name) # skip this node in visualization
+ while (
+ get_op_name(parent.description) in IDENTITY_OPERATORS
+ ) and simple_visualization:
+ assert (
+ len(parent.parents) == 1
+ ) # identity operators should have only one parent
+ visited.add(
+ parent.py_name
+ ) # skip this node in visualization
parent = parent.parents[0]
- edge = (node.py_name, parent.py_name) if reverse_plot else (parent.py_name, node.py_name)
+ edge = (
+ (node.py_name, parent.py_name)
+ if reverse_plot
+ else (parent.py_name, node.py_name)
+ )
# Just plot the edge once, since the same node can be
# visited multiple times (e.g., when that node has
# multiple children).
@@ -1252,7 +1265,7 @@ def __trunc__(self):
return ops.trunc(self)
- ## Normal arithmetic operators
+ # Normal arithmetic operators
def __add__(self, other):
"""Return the sum of the node and another value.
@@ -1821,6 +1834,7 @@ def eq(self, other):
Otherwise, it will return a MessageNode.
"""
import opto.trace.operators as ops
+
return ops.eq(self, node(other))
def neq(self, other):
@@ -1837,6 +1851,7 @@ def neq(self, other):
Otherwise, it will return a MessageNode.
"""
import opto.trace.operators as ops
+
return ops.neq(self, node(other))
def __hash__(self):
@@ -1862,7 +1877,9 @@ def __bool__(self):
# string operators
def format(self, *args, **kwargs):
if type(self._data) is not str:
- raise AttributeError(f"{type(self._data)} object has no attribute 'format'.")
+ raise AttributeError(
+ f"{type(self._data)} object has no attribute 'format'."
+ )
import opto.trace.operators as ops
@@ -1870,7 +1887,9 @@ def format(self, *args, **kwargs):
def capitalize(self):
if type(self._data) is not str:
- raise AttributeError(f"{type(self._data)} object has no attribute 'capitalize'.")
+ raise AttributeError(
+ f"{type(self._data)} object has no attribute 'capitalize'."
+ )
import opto.trace.operators as ops
return ops.capitalize(self)
@@ -1891,7 +1910,9 @@ def upper(self):
def swapcase(self):
if type(self._data) is not str:
- raise AttributeError(f"{type(self._data)} object has no attribute 'swapcase'.")
+ raise AttributeError(
+ f"{type(self._data)} object has no attribute 'swapcase'."
+ )
import opto.trace.operators as ops
return ops.swapcase(self)
@@ -1919,7 +1940,9 @@ def strip(self, chars=None):
def replace(self, old, new, count=-1):
if type(self._data) is not str:
- raise AttributeError(f"{type(self._data)} object has no attribute 'replace'.")
+ raise AttributeError(
+ f"{type(self._data)} object has no attribute 'replace'."
+ )
import opto.trace.operators as ops
return ops.replace(self, node(old), node(new), count)
@@ -1931,7 +1954,7 @@ def join(self, seq):
try:
iter(seq)
except TypeError:
- raise TypeError(f"Can only join an iterable.")
+ raise TypeError("Can only join an iterable.")
import opto.trace.operators as ops
@@ -1971,30 +1994,39 @@ def append(self, *args, **kwargs):
class ParameterNode(Node[T]):
# This is a shorthand of a trainable Node.
def __init__(
- self,
- value,
- *,
- name=None,
- trainable=True,
- description="[ParameterNode] This is a ParameterNode in a computational graph.",
- constraint=None,
- info=None,
+ self,
+ value,
+ *,
+ name=None,
+ trainable=True,
+ description="[ParameterNode] This is a ParameterNode in a computational graph.",
+ constraint=None,
+ info=None,
) -> None:
if description is None or description == "":
- description = "[ParameterNode] This is a ParameterNode in a computational graph."
+ description = (
+ "[ParameterNode] This is a ParameterNode in a computational graph."
+ )
matched = re.match(r"^\[([^\[\]]+)\]", description)
if not matched:
- description = '[ParameterNode] ' + description.strip()
+ description = "[ParameterNode] " + description.strip()
super().__init__(
- value, name=name, trainable=trainable, description=description, constraint=constraint, info=info
+ value,
+ name=name,
+ trainable=trainable,
+ description=description,
+ constraint=constraint,
+ info=info,
)
- self._dependencies['parameter'].add(self)
+ self._dependencies["parameter"].add(self)
def __str__(self) -> str:
# str(node) allows us to look up in the feedback dictionary easily
- return f"ParameterNode: ({self.name}, dtype={type(self._data)}, data={self._data})"
+ return (
+ f"ParameterNode: ({self.name}, dtype={type(self._data)}, data={self._data})"
+ )
class MessageNode(Node[T]):
@@ -2015,28 +2047,27 @@ class MessageNode(Node[T]):
Attributes:
value: The output value of the operator
- inputs (Union[List[Node], Dict[str, Node]]): Input nodes to the operator
- description (str): Description string starting with [operator_name]
- constraint: Optional constraints on the output
- name (str, optional): Name of the node
- info (optional): Additional operator information
"""
# TODO document what needs to go into info
def __init__(
- self,
- value,
- *,
- inputs: Union[List[Node], Dict[str, Node]], # extra
- description: str,
- constraint=None,
- name=None,
- info=None,
+ self,
+ value,
+ *,
+ inputs: Union[List[Node], Dict[str, Node]], # extra
+ description: str,
+ constraint=None,
+ name=None,
+ info=None,
) -> None:
- super().__init__(value, name=name, description=description, constraint=constraint, info=info)
+ super().__init__(
+ value, name=name, description=description, constraint=constraint, info=info
+ )
- assert isinstance(inputs, list) or isinstance(inputs, dict), "Inputs to MessageNode must be a list or a dict."
+ assert isinstance(inputs, list) or isinstance(
+ inputs, dict
+ ), "Inputs to MessageNode must be a list or a dict."
# If inputs is not a dict, we create a dict with the names of the nodes as keys
if isinstance(inputs, list):
inputs = {v.name: v for v in inputs}
@@ -2044,91 +2075,125 @@ def __init__(
# If not tracing, MessageNode would just behave like a Node.
if not GRAPH.TRACE:
- assert len(self._inputs) == 0, "MessageNode should have no inputs when not tracing."
+ assert (
+ len(self._inputs) == 0
+ ), "MessageNode should have no inputs when not tracing."
# Add parents if we are tracing
for k, v in self._inputs.items():
assert isinstance(v, Node), f"Input {k} is not a Node."
self._add_parent(v)
- self._add_dependencies(v) # Initializes the dependencies on parameter and expandable nodes
+ self._add_dependencies(
+ v
+ ) # Initializes the dependencies on parameter and expandable nodes
if len(self.hidden_dependencies) > 0:
- self._dependencies['expandable'].add(self)
+ self._dependencies["expandable"].add(self)
@property
def inputs(self):
+ """(Union[List[Node], Dict[str, Node]]): Input nodes to the operator"""
return copy.copy(self._inputs)
def __str__(self) -> str:
# str(node) allows us to look up in the feedback dictionary easily
- return f"MessageNode: ({self.name}, dtype={type(self._data)}, data={self._data})"
+ return (
+ f"MessageNode: ({self.name}, dtype={type(self._data)}, data={self._data})"
+ )
def _add_feedback(self, child, feedback):
"""Add feedback from a child."""
super()._add_feedback(child, feedback)
- assert len(self._feedback[child]) == 1, "MessageNode should have only one feedback from each child."
+ assert (
+ len(self._feedback[child]) == 1
+ ), "MessageNode should have only one feedback from each child."
@property
def hidden_dependencies(self): # this needs to be recursive
- """ Returns the set of hidden dependencies that are not visible in the current graph level."""
+ """Returns the set of hidden dependencies that are not visible in the current graph level."""
diff = set()
inputs, output = [None], None
if isinstance(self.info, dict):
- if 'inputs' in self.info:
- inputs = list(self.info['inputs']['args']) + list(self.info['inputs']['kwargs'].values())
- if 'output' in self.info:
- output = self.info['output']
-
- if isinstance(self.info, dict) and \
- isinstance(output, Node) and all(isinstance(i, Node) for i in inputs): # traceable code
+ if "inputs" in self.info:
+ inputs = list(self.info["inputs"]["args"]) + list(
+ self.info["inputs"]["kwargs"].values()
+ )
+ if "output" in self.info:
+ output = self.info["output"]
+
+ if (
+ isinstance(self.info, dict)
+ and isinstance(output, Node)
+ and all(isinstance(i, Node) for i in inputs)
+ ): # traceable code
# The inner function is traceable.
diff = diff | (
- output.parameter_dependencies - self.parameter_dependencies) # add extra parameters explicitly used in the inner function
- extra_expandable = output.expandable_dependencies - self.expandable_dependencies
+ output.parameter_dependencies - self.parameter_dependencies
+ ) # add extra parameters explicitly used in the inner function
+ extra_expandable = (
+ output.expandable_dependencies - self.expandable_dependencies
+ )
for n in extra_expandable: # add extra hidden dependencies
diff = diff | n.hidden_dependencies
return diff
def _add_dependencies(self, parent):
assert parent is not self, "Cannot add self as a parent."
- assert isinstance(parent, Node), f"{parent} is {type(parent)}, which is not a Node."
- self._dependencies['parameter'] = self._dependencies['parameter'] | parent._dependencies['parameter']
- self._dependencies['expandable'] = self._dependencies['expandable'] | parent._dependencies['expandable']
+ assert isinstance(
+ parent, Node
+ ), f"{parent} is {type(parent)}, which is not a Node."
+ self._dependencies["parameter"] = (
+ self._dependencies["parameter"] | parent._dependencies["parameter"]
+ )
+ self._dependencies["expandable"] = (
+ self._dependencies["expandable"] | parent._dependencies["expandable"]
+ )
class ExceptionNode(MessageNode[T]):
"""Node containing the exception message."""
def __init__(
- self,
- value: Exception,
- *,
- inputs: Union[List[Node], Dict[str, Node]],
- description: str = "[ExceptionNode] This is node containing the error of execution.",
- constraint=None,
- name=None,
- info=None,
+ self,
+ value: Exception,
+ *,
+ inputs: Union[List[Node], Dict[str, Node]],
+ description: str = "[ExceptionNode] This is node containing the error of execution.",
+ constraint=None,
+ name=None,
+ info=None,
) -> None:
e = value
error_type = re.search(r"", str(type(e))).group(1)
- from opto import trace
+
value = f"({error_type}) {str(e)}"
- super().__init__(value, inputs=inputs, description=description, constraint=constraint, name=name, info=info)
+ super().__init__(
+ value,
+ inputs=inputs,
+ description=description,
+ constraint=constraint,
+ name=name,
+ info=info,
+ )
- def create_feedback(self, style='simple'):
- assert style in ('simple', 'full')
+ def create_feedback(self, style="simple"):
+ assert style in ("simple", "full")
feedback = self._data
- if style in ('line', 'full'):
- if type(self.info) == dict and self.info.get('error_comment') is not None:
- feedback = self.info['error_comment']
+ if style in ("line", "full"):
+ if type(self.info) is dict and self.info.get("error_comment") is not None:
+ feedback = self.info["error_comment"]
return feedback
if __name__ == "__main__":
x = node("Node X")
y = node("Node Y")
- z = MessageNode("Node Z", inputs={"x": x, "y": y}, description="[Add] This is an add operator of x and y.")
+ z = MessageNode(
+ "Node Z",
+ inputs={"x": x, "y": y},
+ description="[Add] This is an add operator of x and y.",
+ )
print(x.name, y.name)
print([p.name for p in z.parents])
@@ -2136,5 +2201,7 @@ def create_feedback(self, style='simple'):
x: Node[str] = node("Node X")
x: ParameterNode[str] = ParameterNode("Node X", trainable=True)
x: MessageNode[str] = MessageNode(
- "Node X", inputs={"x": x, "y": y}, description="[Add] This is an add operator of x and y."
+ "Node X",
+ inputs={"x": x, "y": y},
+ description="[Add] This is an add operator of x and y.",
)
diff --git a/opto/trace/operators.py b/opto/trace/operators.py
index 92e72839..fa7aee00 100644
--- a/opto/trace/operators.py
+++ b/opto/trace/operators.py
@@ -10,7 +10,7 @@
@bundle()
def clone(x: Any):
- """ This is a clone operator of x. """
+ """This is a clone operator of x."""
return copy.deepcopy(x)
@@ -24,37 +24,37 @@ def identity(x: Any):
@bundle()
def pos(x: Any):
- """ This is a pos operator of x. """
+ """This is a pos operator of x."""
return +x
@bundle()
def neg(x: Any):
- """ This is a neg operator of x. """
+ """This is a neg operator of x."""
return -x
@bundle()
def abs(x: Any):
- """ This is an abs operator of x. """
+ """This is an abs operator of x."""
return abs(x)
@bundle()
def invert(x: Any):
- """ This is an invert operator of x. """
+ """This is an invert operator of x."""
return ~x
@bundle()
def round(x: Any, n: Any):
- """ This is a round operator of x. """
+ """This is a round operator of x."""
return round(x, n)
@bundle()
def floor(x: Any):
- """ This is a floor operator of x. """
+ """This is a floor operator of x."""
import math
return math.floor(x)
@@ -62,7 +62,7 @@ def floor(x: Any):
@bundle()
def ceil(x: Any):
- """ This is a ceil operator of x. """
+ """This is a ceil operator of x."""
import math
return math.ceil(x)
@@ -70,7 +70,7 @@ def ceil(x: Any):
@bundle()
def trunc(x: Any):
- """ This is a trunc operator of x. """
+ """This is a trunc operator of x."""
import math
return math.trunc(x)
@@ -81,79 +81,79 @@ def trunc(x: Any):
@bundle()
def add(x: Any, y: Any):
- """ This is an add operator of x and y. """
+ """This is an add operator of x and y."""
return x + y
@bundle()
def subtract(x: Any, y: Any):
- """ This is a subtract operator of x and y. """
+ """This is a subtract operator of x and y."""
return x - y
@bundle()
def multiply(x: Any, y: Any):
- """ This is a multiply operator of x and y. """
+ """This is a multiply operator of x and y."""
return x * y
@bundle()
def floor_divide(x: Any, y: Any):
- """ This is a floor_divide operator of x and y. """
+ """This is a floor_divide operator of x and y."""
return x // y
@bundle()
def divide(x: Any, y: Any):
- """ This is a divide operator of x and y. """
+ """This is a divide operator of x and y."""
return x / y
@bundle()
def mod(x: Any, y: Any):
- """ This is a mod operator of x and y. """
+ """This is a mod operator of x and y."""
return x % y
@bundle()
def node_divmod(x: Any, y: Any):
- """ This is a divmod operator of x and y. """
+ """This is a divmod operator of x and y."""
return divmod(x, y)
@bundle()
def power(x: Any, y: Any):
- """ This is a power operator of x and y. """
+ """This is a power operator of x and y."""
return x**y
@bundle()
def lshift(x: Any, y: Any):
- """ This is a lshift operator of x and y. """
+ """This is a lshift operator of x and y."""
return x << y
@bundle()
def rshift(x: Any, y: Any):
- """ This is a rshift operator of x and y. """
+ """This is a rshift operator of x and y."""
return x >> y
@bundle()
def and_(x: Any, y: Any):
- """ This is an and operator of x and y. """
+ """This is an and operator of x and y."""
return x & y
@bundle()
def or_(x: Any, y: Any):
- """ This is an or operator of x and y. """
+ """This is an or operator of x and y."""
return x | y
@bundle()
def xor(x: Any, y: Any):
- """ This is a xor operator of x and y. """
+ """This is a xor operator of x and y."""
return x ^ y
@@ -162,41 +162,43 @@ def xor(x: Any, y: Any):
@bundle()
def lt(x: Any, y: Any):
- """ This is a lt operator of x and y. """
+ """This is a lt operator of x and y."""
return x < y
@bundle()
def le(x: Any, y: Any):
- """ This is a le operator of x and y. """
+ """This is a le operator of x and y."""
return x <= y
@bundle()
def eq(x: Any, y: Any):
- """ This is an eq operator of x and y. """
+ """This is an eq operator of x and y."""
return x == y
+
@bundle()
def neq(x: Any, y: Any):
- """ This is a not eq operator of x and y. """
+ """This is a not eq operator of x and y."""
return x != y
+
@bundle()
def ne(x: Any, y: Any):
- """ This is a ne operator of x and y. """
+ """This is a ne operator of x and y."""
return x != y
@bundle()
def ge(x: Any, y: Any):
- """ This is a ge operator of x and y. """
+ """This is a ge operator of x and y."""
return x >= y
@bundle()
def gt(x: Any, y: Any):
- """ This is a gt operator of x and y. """
+ """This is a gt operator of x and y."""
return x > y
@@ -205,140 +207,142 @@ def gt(x: Any, y: Any):
@bundle()
def cond(condition: Any, x: Any, y: Any):
- """ This selects x if condition is True, otherwise y. """
+ """This selects x if condition is True, otherwise y."""
x, y, condition = x, y, condition # This makes sure all data are read
return x if condition else y
@bundle()
def not_(x: Any):
- """ This is a not operator of x. """
+ """This is a not operator of x."""
return not x
@bundle()
def is_(x: Any, y: Any):
- """ Whether x is equal to y. """
+ """Whether x is equal to y."""
return x is y
@bundle()
def is_not(x: Any, y: Any):
- """ Whether x is not equal to y. """
+ """Whether x is not equal to y."""
return x is not y
@bundle()
def in_(x: Any, y: Any):
- """ Whether x is in y. """
+ """Whether x is in y."""
return x in y
@bundle()
def not_in(x: Any, y: Any):
- """ Whether x is not in y. """
+ """Whether x is not in y."""
return x not in y
# Indexing and slicing
@bundle()
def getitem(x: Any, index: Any):
- """ This is a getitem operator of x based on index. """
+ """This is a getitem operator of x based on index."""
return x[index]
@bundle()
def pop(x: Any, index: Any):
- """ This is a pop operator of x based on index. """
+ """This is a pop operator of x based on index."""
return x.pop(index)
@bundle()
def len_(x: Any):
- """ This is a len operator of x. """
+ """This is a len operator of x."""
return len(x)
# String operators
@bundle()
def ord_(x: Any):
- """ The unicode number of a character. """
+ """The unicode number of a character."""
return ord(x)
@bundle()
def chr_(x: Any):
- """ The character of a unicode number. """
+ """The character of a unicode number."""
return chr(x)
@bundle()
def concat(x: Any, y: Any):
- """ This is a concatenation operator of x and y. """
+ """This is a concatenation operator of x and y."""
return x + y
@bundle()
def lower(x: Any):
- """ This makes all characters in x lower case. """
+ """This makes all characters in x lower case."""
return x.lower()
@bundle()
def upper(x: Any):
- """ This makes all characters in x upper case. """
+ """This makes all characters in x upper case."""
return x.upper()
@bundle()
def title(x: Any):
- """ This makes the first character to upper case and the rest to lower case. """
+ """This makes the first character to upper case and the rest to lower case."""
return x.title()
@bundle()
def swapcase(x: Any):
- """ Swaps the case of all characters: uppercase character to lowercase and vice-versa. """
+ """Swaps the case of all characters: uppercase character to lowercase and vice-versa."""
return x.swapcase()
@bundle()
def capitalize(x: Any):
- """ Converts the first character of a string to uppercase. """
+ """Converts the first character of a string to uppercase."""
return x.capitalize()
@bundle()
def split(x: Any, y: Any, maxsplit: Any = -1):
- """ Splits the string by finding a substring y in string x, return the first part and second part of string x without y. """
+ """Splits the string by finding a substring y in string x, return the first part and second part of string x without y."""
return x.split(y, maxsplit)
@bundle()
def strip(x: Any, chars=None):
- """ Removes the leading and trailing characters of x. """
+ """Removes the leading and trailing characters of x."""
return x.strip(chars)
@bundle()
def replace(x: Any, old: Any, new: Any, count: Any = -1):
- """ Replaces all occurrences of substring y in string x with z. """
+ """Replaces all occurrences of substring y in string x with z."""
return x.replace(old, new, count)
@bundle()
def format(x: Any, *args, **kwargs):
- """ Fills in a string template with content, str.format(). """
+ """Fills in a string template with content, str.format()."""
return x.format(*args, **kwargs)
+
@bundle()
def join(x: Any, *y: Any):
- """ Joins a sequence y with different strs with x: "\n".join(["a", "b", "c"]) -> "a\nb\nc". """
+ """Joins a sequence y with different strs with x: "\n".join(["a", "b", "c"]) -> "a\nb\nc"."""
return x.join(y)
+
@bundle()
def node_getattr(obj: Node, attr: str):
- """ This operator gets attr of obj. """
+ """This operator gets attr of obj."""
return getattr(obj, attr)
@@ -347,7 +351,7 @@ def node_getattr(obj: Node, attr: str):
allow_external_dependencies=True,
)
def call(fun: Node, *args, **kwargs):
- """ This operator calls the function `fun` with args (args_0, args_1, etc.) and kwargs. If there are no args or kwargs, i.e. call(fun=function_name), the function takes no input. """
+ """This operator calls the function `fun` with args (args_0, args_1, etc.) and kwargs. If there are no args or kwargs, i.e. call(fun=function_name), the function takes no input."""
# Run the function as it is
fun = fun._data
# Call the node with the input arguments
@@ -358,125 +362,146 @@ def call(fun: Node, *args, **kwargs):
@bundle()
def to_list(x: Any):
- """ This converts x to a list. """
+ """This converts x to a list."""
return list(x)
+
@bundle()
def make_list(*args):
- """ This creates a list from the arguments. """
+ """This creates a list from the arguments."""
return list(args)
+
@bundle()
def to_dict(x: Any):
- """ This converts x to a dictionary. """
+ """This converts x to a dictionary."""
return dict(x)
+
@bundle()
def make_dict(**kwargs):
- """ This creates a dictionary from the keyword arguments. """
+ """This creates a dictionary from the keyword arguments."""
return kwargs
+
@bundle()
def to_set(x: Any):
- """ This converts x to a set. """
+ """This converts x to a set."""
return set(x)
+
@bundle()
def make_set(*args):
- """ This creates a set from the arguments. """
+ """This creates a set from the arguments."""
return set(args)
+
@bundle()
def to_tuple(x: Any):
- """ This converts x to a tuple. """
+ """This converts x to a tuple."""
return tuple(x)
+
@bundle()
def make_tuple(*args):
- """ This creates a tuple from the arguments. """
+ """This creates a tuple from the arguments."""
return tuple(args)
+
# dict operators
+
@bundle()
def keys(x: Dict):
- """ Return the keys of a dictionary x as a list. """
+ """Return the keys of a dictionary x as a list."""
if not isinstance(x, dict):
raise AttributeError(f"{type(x)} object has no attribute 'values'.")
return [k for k in x.keys()]
+
@bundle()
def values(x: Dict):
- """ Return the values of a dictionary x as a list. """
+ """Return the values of a dictionary x as a list."""
if not isinstance(x, dict):
raise AttributeError(f"{type(x)} object has no attribute 'values'.")
return [k for k in x.values()]
+
# dict in-place operators
+
@bundle()
def dict_update(x: Dict, y: Dict):
- """ Update the dictionary x with the dictionary y. """
+ """Update the dictionary x with the dictionary y."""
x = copy.copy(x)
x.update(y)
return x
+
@bundle()
def dict_pop(x: Dict, key: Any):
- """ Pop the key from the dictionary x. """
+ """Pop the key from the dictionary x."""
x = copy.copy(x)
x.pop(key)
return x
+
@bundle()
def dict_popitem(x: Dict):
- """ Pop the last item from the dictionary x. """
+ """Pop the last item from the dictionary x."""
x = copy.copy(x)
x.popitem()
return x
+
# list in-place operators
+
@bundle()
def list_append(x: Any, y: Any):
- """ Append y to x. """
+ """Append y to x."""
x = copy.copy(x)
x.append(y)
return x
+
@bundle()
def list_clear(x: Any):
- """ Clear x. """
+ """Clear x."""
x = copy.copy(x)
x.clear()
return x
+
@bundle()
def list_extend(x: Any, y: Any):
- """ Extend x with y. """
+ """Extend x with y."""
x = copy.copy(x)
x.extend(y)
return x
+
@bundle()
def list_insert(x: Any, index: Any, y: Any):
- """ Insert y at index in x. """
+ """Insert y at index in x."""
x = copy.copy(x)
x.insert(index, y)
return x
+
@bundle()
def list_pop(x: Any, index: Any):
- """ Pop the index from x. """
+ """Pop the index from x."""
x = copy.copy(x)
x.pop(index)
return x
+
@bundle()
def list_remove(x: Any, y: Any):
- """ Remove y from x. """
+ """Remove y from x."""
x = copy.copy(x)
x.remove(y)
return x
@@ -484,7 +509,7 @@ def list_remove(x: Any, y: Any):
@bundle()
def list_reverse(x: Any):
- """ Reverse x. """
+ """Reverse x."""
x = copy.copy(x)
x.reverse()
return x
@@ -492,7 +517,7 @@ def list_reverse(x: Any):
@bundle()
def list_sort(x: Any, key: Any = None, reverse: Any = False):
- """ Sort x. """
+ """Sort x."""
x = copy.copy(x)
x.sort(key=key, reverse=reverse)
return x
@@ -501,67 +526,76 @@ def list_sort(x: Any, key: Any = None, reverse: Any = False):
# set in-place operators
@bundle()
def set_add(x: Any, y: Any):
- """ Add y to x. """
+ """Add y to x."""
x = copy.copy(x)
x.add(y)
return x
+
@bundle()
def set_clear(x: Any):
- """ Clear x. """
+ """Clear x."""
x = copy.copy(x)
x.clear()
return x
+
@bundle()
def set_discard(x: Any, y: Any):
- """ Discard y from x. """
+ """Discard y from x."""
x = copy.copy(x)
x.discard(y)
return x
+
@bundle()
def set_intersection_update(x: Any, y: Any):
- """ Update x with the intersection of x and y. """
+ """Update x with the intersection of x and y."""
x = copy.copy(x)
x.intersection_update(y)
return x
+
@bundle()
def set_pop(x: Any):
- """ Pop an element from x. """
+ """Pop an element from x."""
x = copy.copy(x)
x.pop()
return x
+
@bundle()
def set_remove(x: Any, y: Any):
- """ Remove y from x. """
+ """Remove y from x."""
x = copy.copy(x)
x.remove(y)
return x
+
@bundle()
def set_symmetric_difference_update(x: Any, y: Any):
- """ Update x with the symmetric difference of x and y. """
+ """Update x with the symmetric difference of x and y."""
x = copy.copy(x)
x.symmetric_difference_update(y)
return x
+
@bundle()
def set_update(x: Any, y: Any):
- """ Update x with y. """
+ """Update x with y."""
x = copy.copy(x)
x.update(y)
return x
+
@bundle()
def call_llm(system_prompt, *user_prompts, **kwargs):
- """ Query the language model of system_prompt with user_prompts."""
+ """Query the language model of system_prompt with user_prompts."""
messages = [{"role": "system", "content": system_prompt}]
for user_prompt in user_prompts:
messages.append({"role": "user", "content": user_prompt})
from opto.utils.llm import AutoGenLLM
+
llm = AutoGenLLM()
response = llm(messages=messages, **kwargs)
- return response.choices[0].message.content
\ No newline at end of file
+ return response.choices[0].message.content
diff --git a/opto/trace/propagators/graph_propagator.py b/opto/trace/propagators/graph_propagator.py
index a8a025a3..6e5d9dbb 100644
--- a/opto/trace/propagators/graph_propagator.py
+++ b/opto/trace/propagators/graph_propagator.py
@@ -1,15 +1,25 @@
from dataclasses import dataclass
from typing import Any, List, Dict, Tuple
-from opto.trace.nodes import Node, MessageNode, ParameterNode, get_op_name, IDENTITY_OPERATORS, NodeVizStyleGuideColorful
+from opto.trace.nodes import (
+ Node,
+ MessageNode,
+ ParameterNode,
+ get_op_name,
+ IDENTITY_OPERATORS,
+ NodeVizStyleGuideColorful,
+)
from opto.trace.propagators.propagators import Propagator, AbstractFeedback
import heapq
from opto.trace.utils import sum_feedback
+
@dataclass
class TraceGraph(AbstractFeedback):
"""Feedback container used by GraphPropagator."""
- graph: List[Tuple[int,Node]] # a priority queue of nodes in the subgraph, ordered from roots to leaves
+ graph: List[
+ Tuple[int, Node]
+ ] # a priority queue of nodes in the subgraph, ordered from roots to leaves
user_feedback: Any
def empty(self):
@@ -23,9 +33,15 @@ def __add__(self, other):
self.user_feedback is None and other.user_feedback is None
), "One of the user feedback should not be None."
if self.user_feedback is None or other.user_feedback is None:
- user_feedback = self.user_feedback if other.user_feedback is None else other.user_feedback
+ user_feedback = (
+ self.user_feedback
+ if other.user_feedback is None
+ else other.user_feedback
+ )
else: # both are not None
- assert self.user_feedback == other.user_feedback, "user feedback should be the same for all children"
+ assert (
+ self.user_feedback == other.user_feedback
+ ), "user feedback should be the same for all children"
user_feedback = self.user_feedback
other_names = [id(n[1]) for n in other.graph]
@@ -35,21 +51,23 @@ def __add__(self, other):
graph = [x for x in heapq.merge(complement, other.graph, key=lambda x: x[0])]
return TraceGraph(graph=graph, user_feedback=user_feedback)
-
@classmethod
def expand(cls, node: MessageNode):
- """ Return the subgraph within a MessageNode. """
+ """Return the subgraph within a MessageNode."""
assert isinstance(node, MessageNode)
- if isinstance(node.info['output'], MessageNode):
+ if isinstance(node.info["output"], MessageNode):
# these are the nodes where we will collect the feedback
- roots = list(node.info['output'].parameter_dependencies) + \
- list(node.info['output'].expandable_dependencies) + \
- node.info['inputs']['args'] + [v for v in node.info['inputs']['kwargs'].values()]
+ roots = (
+ list(node.info["output"].parameter_dependencies)
+ + list(node.info["output"].expandable_dependencies)
+ + node.info["inputs"]["args"]
+ + [v for v in node.info["inputs"]["kwargs"].values()]
+ )
# remove old feedback, since we need to call backard again; we will restore it later
old_feedback = {p: p._feedback for p in roots}
for p in roots:
p.zero_feedback()
- node.info['output'].backward('', retain_graph=True)
+ node.info["output"].backward("", retain_graph=True)
subgraph = sum_feedback(roots)
# restore the old feedback
for p, feedback in old_feedback.items():
@@ -72,10 +90,12 @@ def visualize(self, simple_visualization=True, reverse_plot=False, print_limit=1
nvsg = NodeVizStyleGuideColorful(print_limit=print_limit)
- queue = sorted(self.graph, key=lambda x: x[0]) # sort by level
+ queue = sorted(self.graph, key=lambda x: x[0]) # sort by level
digraph = Digraph()
- if len(queue) == 1 and len(queue[0][1].parents) == 0: # This is a root. Nothing to propagate
+ if (
+ len(queue) == 1 and len(queue[0][1].parents) == 0
+ ): # This is a root. Nothing to propagate
digraph.node(queue[0][1].py_name, **nvsg.get_attrs(queue[0][1]))
return digraph
@@ -89,12 +109,17 @@ def visualize(self, simple_visualization=True, reverse_plot=False, print_limit=1
for parent in node.parents:
if self._itemize(parent) in queue:
# if there's a parent, add an edge, otherwise no need
- edge = (node.py_name, parent.py_name) if reverse_plot else (parent.py_name, node.py_name)
+ edge = (
+ (node.py_name, parent.py_name)
+ if reverse_plot
+ else (parent.py_name, node.py_name)
+ )
digraph.edge(*edge)
digraph.node(parent.py_name, **nvsg.get_attrs(parent))
return digraph
+
class GraphPropagator(Propagator):
"""A propagator that collects all the nodes seen in the path."""
@@ -103,8 +128,9 @@ def init_feedback(self, node, feedback: Any):
def _propagate(self, child: MessageNode):
graph = [(p.level, p) for p in child.parents] # add the parents
- graph = sorted(graph, key=lambda x: x[0]) # sort by level, heapq in TraceGraph requires this
- feedback = self.aggregate(child.feedback) + TraceGraph(graph=graph, user_feedback=None)
+ feedback = self.aggregate(child.feedback) + TraceGraph(
+ graph=graph, user_feedback=None
+ )
assert isinstance(feedback, TraceGraph)
# For including the external dependencies on parameters not visible
diff --git a/opto/trace/propagators/propagators.py b/opto/trace/propagators/propagators.py
index dba5274d..101f538f 100644
--- a/opto/trace/propagators/propagators.py
+++ b/opto/trace/propagators/propagators.py
@@ -36,6 +36,7 @@ def __radd__(self, other):
else:
return self.__add__(other)
+
class Propagator(AbstractPropagator):
def __init__(self):
self.override = dict() # key: operator name: data: override propagate function
@@ -70,6 +71,7 @@ def _propagate(self, child: MessageNode) -> Dict[Node, Any]:
# if len(feedback) > 1, it means there are two or more child nodes from this node,
# we might need to perform a "merge" feedback action
+
# # TODO test
class SumPropagator(Propagator):
def init_feedback(self, feedback: Any):
@@ -84,7 +86,9 @@ def _propagate(self, child: MessageNode):
# Simply sum the feedback
feedback_list = [v[0] for k, v in child.feedback.items()]
assert len(feedback_list) > 0
- assert all([type(feedback_list[0]) == type(f) for f in feedback_list]), "error in propagate"
+ assert all(
+ [type(feedback_list[0]) is type(f) for f in feedback_list]
+ ), "error in propagate"
if isinstance(feedback_list[0], str):
feedback = "".join(feedback_list)
else:
diff --git a/opto/trace/utils.py b/opto/trace/utils.py
index 999e6a30..10be762e 100644
--- a/opto/trace/utils.py
+++ b/opto/trace/utils.py
@@ -6,11 +6,13 @@
# Get a list of all names in the builtins module
builtins_list = dir(builtins)
# Filter for function names; this includes exceptions, so you might want to refine this
-global_functions_list = [name for name in builtins_list if callable(getattr(builtins, name))]
+global_functions_list = [
+ name for name in builtins_list if callable(getattr(builtins, name))
+]
def sum_feedback(nodes):
- """ Aggregate the feedback of a list of nodes. """
+ """Aggregate the feedback of a list of nodes."""
return sum([sum(gg) for p in nodes for gg in p.feedback.values()])
@@ -23,15 +25,16 @@ def parse_eqs_to_dict(text):
"""
Parse the text of equations into a dictionary
+ Example:
x0 = 1
x1=2
x2=`2`
- x3= def fun():\n print('hello')\n
+ x3= def fun():\\n print('hello')\\n
abc_test1=test
would be parsed into
- {'x0': '1', 'x1': '2', 'x2': '2', 'x3': "def fun():\nprint('hello')", 'abc_test1': 'test'}
+ {'x0': '1', 'x1': '2', 'x2': '2', 'x3': "def fun():\\nprint('hello')", 'abc_test1': 'test'}
"""
lines = text.split("\n")
result_dict = {}
@@ -64,33 +67,36 @@ def render_opt_step(step_idx, optimizer, no_trace_graph=False, no_improvement=Fa
from IPython.display import display, HTML
idx = step_idx
- llm_response = json.loads(optimizer.log[idx]['response'])
- r1 = llm_response['reasoning']
-
- if llm_response.get('suggestion'):
- a1 = ''.join(
- [f"{var_name}:\n\n{var_body}\n\n" for var_name, var_body in llm_response['suggestion'].items()]
+ llm_response = json.loads(optimizer.log[idx]["response"])
+ r1 = llm_response["reasoning"]
+
+ if llm_response.get("suggestion"):
+ a1 = "".join(
+ [
+ f"{var_name}:\n\n{var_body}\n\n"
+ for var_name, var_body in llm_response["suggestion"].items()
+ ]
)
- elif llm_response.get('answer') is not None:
- a1 = llm_response['answer']
+ elif llm_response.get("answer") is not None:
+ a1 = llm_response["answer"]
else:
a1 = " NULL/INVALID RESPONSE"
- pi = optimizer.summary_log[idx]['problem_instance'] # full
+ pi = optimizer.summary_log[idx]["problem_instance"] # full
f1 = pi.feedback
- masked = ['#Feedback', '#Others', '#Instruction']
- pi = optimizer.problem_instance(optimizer.summary_log[idx]['summary'], mask=masked)
+ masked = ["#Feedback", "#Others", "#Instruction"]
+ pi = optimizer.problem_instance(optimizer.summary_log[idx]["summary"], mask=masked)
# a hack to remove "#Feedback:" because it has a colon
pi = str(pi)
pi = pi.replace("#Feedback:", "#Feedback")
for m in masked:
- pi = pi.replace(m + '\n', '')
+ pi = pi.replace(m + "\n", "")
# a quick processing to reduce multiple empty lines to one
- pi = re.sub(r'\n\s*\n', '\n\n', pi)
+ pi = re.sub(r"\n\s*\n", "\n\n", pi)
g1 = pi
html_template = f"""
@@ -153,18 +159,19 @@ def escape_json_nested_quotes(json_str):
{"name": "string value", "value": "string value"}
Does not escape quotes around keys or structural quotes.
- Warning: here are what this function does not do:
- 1. Cannot handle "\\\n" or "\\\t" type of strings
- 2. Do not check if "\n" or "\t" or other control characters are properly escaped
- Please use json_str.replace("\n", "\\n") to escape control characters outside of this function
+ Warning:
+ Here are what this function does not do:
+ 1. Cannot handle "\\n" or "\\t" type of strings
+ 2. Does not check if "\\n", "\\t", or other control characters are properly escaped.
+ Please use json_str.replace("\\n", "\\n") to escape control characters outside of this function.
- Example usage can be found in optimizers/textgrad.py
+ Example usage can be found in optimizers/textgrad.py.
Args:
- json_str (str): A string representation of JSON with exactly two keys: name and value
+ json_str (str): A string representation of JSON with exactly two keys: name and value.
Returns:
- str: JSON string with properly escaped quotes in values
+ str: JSON string with properly escaped quotes in values.
"""
result = []
i = 0
@@ -174,21 +181,25 @@ def escape_json_nested_quotes(json_str):
if char == '"':
# Check if this quote is around "name" or "value"
- next_four = json_str[i + 1:i + 5]
- next_five = json_str[i + 1:i + 6]
- is_key = next_four == 'name' or next_five == 'value'
+ next_four = json_str[i + 1 : i + 5]
+ next_five = json_str[i + 1 : i + 6]
+ is_key = next_four == "name" or next_five == "value"
# Check if this is a structural quote (after : or before })
- prev_char = json_str[i - 1] if i > 0 else ''
- next_char = json_str[i + 1] if i < len(json_str) - 1 else ''
- is_value_boundary = prev_char == ':' or (
- prev_char == ' ' and json_str[i - 2] == ':') or next_char == '}' or next_char == ','
+ prev_char = json_str[i - 1] if i > 0 else ""
+ next_char = json_str[i + 1] if i < len(json_str) - 1 else ""
+ is_value_boundary = (
+ prev_char == ":"
+ or (prev_char == " " and json_str[i - 2] == ":")
+ or next_char == "}"
+ or next_char == ","
+ )
if is_key or is_value_boundary:
result.append(char)
- if prev_char == ':' or (prev_char == ' ' and json_str[i - 2] == ':'):
+ if prev_char == ":" or (prev_char == " " and json_str[i - 2] == ":"):
in_value = True
- if next_char == '}' or next_char == ',':
+ if next_char == "}" or next_char == ",":
in_value = False
else:
# if we double-escpaed like \\", we remove one
@@ -197,7 +208,7 @@ def escape_json_nested_quotes(json_str):
result.append(char)
# If we're in a value and this is not a boundary quote, escape it
elif in_value and prev_char != "\\":
- result.append(r'\"')
+ result.append(r"\"")
else:
result.append(char)
else:
@@ -206,7 +217,15 @@ def escape_json_nested_quotes(json_str):
# JSON can't accept any \ with invalid characters, in here we took a short cut and only keep \ for
# we didn't add \u to this list
- if json_str[i - 1] == "\\" and char not in ["\\", "\/", 'n', 'b', 'f', 'r', 't']:
+ if json_str[i - 1] == "\\" and char not in [
+ "\\",
+ "\/",
+ "n",
+ "b",
+ "f",
+ "r",
+ "t",
+ ]:
result.pop(-1)
result.append(char)
@@ -214,7 +233,7 @@ def escape_json_nested_quotes(json_str):
# print(in_value, ''.join(result))
i += 1
- return ''.join(result)
+ return "".join(result)
def remove_non_ascii(json_txt):
@@ -223,7 +242,7 @@ def remove_non_ascii(json_txt):
"""
cleaned = ""
for c in escape_json_nested_quotes(json_txt):
- if c not in ['\n', '\t', '\b', '\r', '\f'] and not c.isprintable():
+ if c not in ["\n", "\t", "\b", "\r", "\f"] and not c.isprintable():
continue
cleaned += c
return cleaned
@@ -233,36 +252,38 @@ def test_json_quote_escaper():
test_cases = [
(
'{"name": "Multiple "quotes" in "one" string", "value": "Multiple "quotes" in "the second" string"}',
- r'{"name": "Multiple \"quotes\" in \"one\" string", "value": "Multiple \"quotes\" in \"the second\" string"}'
+ r'{"name": "Multiple \"quotes\" in \"one\" string", "value": "Multiple \"quotes\" in \"the second\" string"}',
),
(
'{"name": "Simple "quote"", "value": "Another "quote""}',
- r'{"name": "Simple \"quote\"", "value": "Another \"quote\""}'
+ r'{"name": "Simple \"quote\"", "value": "Another \"quote\""}',
),
(
'{"name": "No quotes here", "value": "But "quotes" here"}',
- r'{"name": "No quotes here", "value": "But \"quotes\" here"}'
+ r'{"name": "No quotes here", "value": "But \"quotes\" here"}',
),
(
'{"name": "Quote at "end"", "value": "Another at "end""}',
- r'{"name": "Quote at \"end\"", "value": "Another at \"end\""}'
+ r'{"name": "Quote at \"end\"", "value": "Another at \"end\""}',
),
(
r'{"name": "Quote at "end"", "value": "Partial at \"end""}',
- r'{"name": "Quote at \"end\"", "value": "Partial at \"end\""}'
+ r'{"name": "Quote at \"end\"", "value": "Partial at \"end\""}',
),
(
r'{"name": "Quote at \\"end\\"", "value": "Partial at \"end""}',
- r'{"name": "Quote at \"end\"", "value": "Partial at \"end\""}'
+ r'{"name": "Quote at \"end\"", "value": "Partial at \"end\""}',
),
(
r'{"name": "Quote at \\"end\\"", "value": "\( \alpha_t \) \\n"}',
- r'{"name": "Quote at \"end\"", "value": "( alpha_t ) \\n"}'
- )
+ r'{"name": "Quote at \"end\"", "value": "( alpha_t ) \\n"}',
+ ),
]
for i, (input_str, expected) in enumerate(test_cases, 1):
result = escape_json_nested_quotes(input_str)
- assert result == expected, f'\nTest case {i} failed:\nInput: {input_str}\nExpected: {expected}\nGot: {result}'
+ assert (
+ result == expected
+ ), f"\nTest case {i} failed:\nInput: {input_str}\nExpected: {expected}\nGot: {result}"
print("All tests passed!")
diff --git a/opto/utils/llm.py b/opto/utils/llm.py
index bd2ac8bf..a67ccc52 100644
--- a/opto/utils/llm.py
+++ b/opto/utils/llm.py
@@ -4,12 +4,14 @@
import json
import autogen # We import autogen here to avoid the need of installing autogen
+
class AbstractModel:
"""
A minimal abstraction of a model api that refreshes the model every
reset_freq seconds (this is useful for long-running models that may require
refreshing certificates or memory management).
"""
+
def __init__(self, factory: Callable, reset_freq: Union[int, None] = None) -> None:
"""
Args:
@@ -29,25 +31,34 @@ def model(self):
# This is the main API
def __call__(self, *args, **kwargs) -> Any:
- """ The call function handles refreshing the model if needed. """
- if self.reset_freq is not None and time.time() - self._init_time > self.reset_freq:
+ """The call function handles refreshing the model if needed."""
+ if (
+ self.reset_freq is not None
+ and time.time() - self._init_time > self.reset_freq
+ ):
self._model = self.factory()
self._init_time = time.time()
return self.model(*args, **kwargs)
def __getstate__(self):
state = self.__dict__.copy()
- state['_model'] = None
+ state["_model"] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._model = self.factory()
-class AutoGenLLM(AbstractModel):
- """ This is the main class Trace uses to interact with the model. It is a wrapper around autogen's OpenAIWrapper. For using models not supported by autogen, subclass AutoGenLLM and override the `_factory` and `create` method. Users can pass instances of this class to optimizers' llm argument. """
- def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_freq: Union[int, None] = None) -> None:
+class AutoGenLLM(AbstractModel):
+ """This is the main class Trace uses to interact with the model. It is a wrapper around autogen's OpenAIWrapper. For using models not supported by autogen, subclass AutoGenLLM and override the `_factory` and `create` method. Users can pass instances of this class to optimizers' llm argument."""
+
+ def __init__(
+ self,
+ config_list: List = None,
+ filter_dict: Dict = None,
+ reset_freq: Union[int, None] = None,
+ ) -> None:
if config_list is None:
try:
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
@@ -57,9 +68,9 @@ def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_fre
os.environ.update({"OAI_CONFIG_LIST": json.dumps(config_list)})
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
if filter_dict is not None:
- config_list = autogen.filter_config(config_list, filter_dict)
+ config_list = autogen.filter_config_list(config_list, filter_dict)
- factory = lambda *args, **kwargs : self._factory(config_list)
+ factory = lambda *args, **kwargs: self._factory(config_list)
super().__init__(factory, reset_freq)
@classmethod
@@ -68,7 +79,7 @@ def _factory(cls, config_list):
@property
def model(self):
- return lambda *args, **kwargs : self.create(*args, **kwargs)
+ return lambda *args, **kwargs: self.create(*args, **kwargs)
# This is main API. We use the API of autogen's OpenAIWrapper
def create(self, **config: Any) -> autogen.ModelClient.ModelClientResponseProtocol:
@@ -94,22 +105,23 @@ def create(self, **config: Any) -> autogen.ModelClient.ModelClientResponseProtoc
Note: this is a legacy argument. It is only used when the cache argument is not provided.
- filter_func (Callable | None): A function that takes in the context and the response
and returns a boolean to indicate whether the response is valid. E.g.,
-
- ```python
- def yes_or_no_filter(context, response):
- return context.get("yes_or_no_choice", False) is False or any(
- text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
- )
- ```
-
- allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false.
- api_version (str | None): The api version. Default to None. E.g., "2024-02-01".
+
+ Example:
+ >>> # filter_func example:
+ >>> def yes_or_no_filter(context, response):
+ >>> return context.get("yes_or_no_choice", False) is False or any(
+ >>> text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
+ >>> )
+
Raises:
- RuntimeError: If all declared custom model clients are not registered
- APIError: If any model client create call raises an APIError
"""
return self._model.create(**config)
+
def auto_construct_oai_config_list_from_env() -> List:
"""
Collect various API keys saved in the environment and return a format like:
@@ -120,7 +132,14 @@ def auto_construct_oai_config_list_from_env() -> List:
"""
config_list = []
if os.environ.get("OPENAI_API_KEY") is not None:
- config_list.append({"model": "gpt-4o", "api_key": os.environ.get("OPENAI_API_KEY")})
+ config_list.append(
+ {"model": "gpt-4o", "api_key": os.environ.get("OPENAI_API_KEY")}
+ )
if os.environ.get("ANTHROPIC_API_KEY") is not None:
- config_list.append({"model": "claude-3-5-sonnet-latest", "api_key": os.environ.get("ANTHROPIC_API_KEY")})
+ config_list.append(
+ {
+ "model": "claude-3-5-sonnet-latest",
+ "api_key": os.environ.get("ANTHROPIC_API_KEY"),
+ }
+ )
return config_list