Skip to content

Commit f96df6c

Browse files
[python/knowpro] Add two new utilities related to Pydantic (#1429)
* 7690421 utils.py: Add make_agent(), a utility to create a Pydantic AI agent * 7cc302b Add setup_logfire(), a utility for setting up Pydantic's logfire
1 parent 8f9b044 commit f96df6c

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

python/ta/typeagent/aitools/utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from contextlib import contextmanager
77
import difflib
88
import os
9+
import re
910
import shutil
1011
import time
1112

@@ -14,6 +15,7 @@
1415
import dotenv
1516
import typechat
1617

18+
from pydantic_ai import Agent
1719

1820
cap = min # More readable name for capping a value at some limit.
1921

@@ -88,6 +90,7 @@ def create_translator[T](
8890

8991
# Vibe-coded by o4-mini-high
9092
def list_diff(label_a, a, label_b, b, max_items):
93+
"""Print colorized diff between two sorted list of numbers."""
9194
sm = difflib.SequenceMatcher(None, a, b)
9295
a_out, b_out = [], []
9396
for tag, i1, i2, j1, j2 in sm.get_opcodes():
@@ -142,3 +145,80 @@ def fmt(row, seg_widths):
142145
seg_widths = widths[start:end]
143146
print(la, fmt(a_cols[start:end], seg_widths))
144147
print(lb, fmt(b_cols[start:end], seg_widths))
148+
149+
150+
def setup_logfire():
151+
"""Configure logfire for pydantic_ai and httpx."""
152+
153+
import logfire
154+
155+
def scrubbing_callback(m: logfire.ScrubMatch):
156+
"""Instructions: Uncomment any block where you deem it safe to not scrub."""
157+
# if m.path == ('attributes', 'http.request.header.authorization'):
158+
# return m.value
159+
160+
# if m.path == ('attributes', 'http.request.header.api-key'):
161+
# return m.value
162+
163+
if (
164+
m.path == ("attributes", "http.request.body.text", "messages", 0, "content")
165+
and m.pattern_match.group(0) == "secret"
166+
):
167+
return m.value
168+
169+
# if m.path == ('attributes', 'http.response.header.azureml-model-session'):
170+
# return m.value
171+
172+
logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
173+
logfire.instrument_pydantic_ai()
174+
logfire.instrument_httpx(capture_all=True)
175+
176+
177+
def make_agent[T](cls: type[T]) -> Agent[None, T]:
178+
"""Create Pydantic AI agent using hardcoded preferences."""
179+
from pydantic_ai import NativeOutput, ToolOutput
180+
from pydantic_ai.models.openai import OpenAIModel
181+
from pydantic_ai.providers.azure import AzureProvider
182+
from .auth import get_shared_token_provider
183+
184+
# Prefer straight OpenAI over Azure OpenAI.
185+
if os.getenv("OPENAI_API_KEY"):
186+
Wrapper = NativeOutput
187+
print(f"## Using OpenAI with {Wrapper.__name__} ##")
188+
model = OpenAIModel("gpt-4o") # Retrieves OPENAI_API_KEY again.
189+
190+
elif azure_openai_api_key := os.getenv("AZURE_OPENAI_API_KEY"):
191+
# This section is rather specific to our team's setup at Microsoft.
192+
if azure_openai_api_key == "identity":
193+
token_provider = get_shared_token_provider()
194+
azure_openai_api_key = token_provider.get_token()
195+
196+
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
197+
if not azure_endpoint:
198+
raise RuntimeError("AZURE_OPENAI_ENDPOINT not found")
199+
200+
print(f"## {azure_endpoint} ##")
201+
m = re.search(r"api-version=([\d-]+(?:preview)?)", azure_endpoint)
202+
if not m:
203+
raise RuntimeError(
204+
f"AZURE_OPENAI_ENDPOINT has no valid api-version field: {azure_endpoint}"
205+
)
206+
api_version = m.group(1)
207+
Wrapper = ToolOutput
208+
209+
print(f"## Using Azure {api_version} with {Wrapper.__name__} ##")
210+
model = OpenAIModel(
211+
"gpt-4o",
212+
provider=AzureProvider(
213+
azure_endpoint=azure_endpoint,
214+
api_version=api_version,
215+
api_key=azure_openai_api_key,
216+
),
217+
)
218+
219+
else:
220+
raise RuntimeError(
221+
"Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided."
222+
)
223+
224+
return Agent(model, output_type=Wrapper(cls, strict=True), retries=3)

0 commit comments

Comments
 (0)