6 | 6 | from contextlib import contextmanager
7 | 7 | import difflib
8 | 8 | import os
| 9 | +import re
9 | 10 | import shutil
10 | 11 | import time
11 | 12 |
14 | 15 | import dotenv
15 | 16 | import typechat
16 | 17 |
| 18 | +from pydantic_ai import Agent
17 | 19 |
18 | 20 | cap = min  # More readable name for capping a value at some limit.
19 | 21 |
@@ -88,6 +90,7 @@ def create_translator[T](
88 | 90 |
89 | 91 | # Vibe-coded by o4-mini-high
90 | 92 | def list_diff(label_a, a, label_b, b, max_items):
| 93 | +    """Print a colorized diff between two sorted lists of numbers."""
91 | 94 |     sm = difflib.SequenceMatcher(None, a, b)
92 | 95 |     a_out, b_out = [], []
93 | 96 |     for tag, i1, i2, j1, j2 in sm.get_opcodes():
@@ -142,3 +145,80 @@ def fmt(row, seg_widths):
142 | 145 |         seg_widths = widths[start:end]
143 | 146 |         print(la, fmt(a_cols[start:end], seg_widths))
144 | 147 |         print(lb, fmt(b_cols[start:end], seg_widths))
| 148 | +
| 149 | +
| 150 | +def setup_logfire():
| 151 | +    """Configure logfire for pydantic_ai and httpx."""
| 152 | +
| 153 | +    import logfire
| 154 | +
| 155 | +    def scrubbing_callback(m: logfire.ScrubMatch):
| 156 | +        """Instructions: Uncomment any block where you deem it safe to not scrub."""
| 157 | +        # if m.path == ('attributes', 'http.request.header.authorization'):
| 158 | +        #     return m.value
| 159 | +
| 160 | +        # if m.path == ('attributes', 'http.request.header.api-key'):
| 161 | +        #     return m.value
| 162 | +
| 163 | +        if (
| 164 | +            m.path == ("attributes", "http.request.body.text", "messages", 0, "content")
| 165 | +            and m.pattern_match.group(0) == "secret"
| 166 | +        ):
| 167 | +            return m.value
| 168 | +
| 169 | +        # if m.path == ('attributes', 'http.response.header.azureml-model-session'):
| 170 | +        #     return m.value
| 171 | +
| 172 | +    logfire.configure(scrubbing=logfire.ScrubbingOptions(callback=scrubbing_callback))
| 173 | +    logfire.instrument_pydantic_ai()
| 174 | +    logfire.instrument_httpx(capture_all=True)
| 175 | +
| 176 | +
| 177 | +def make_agent[T](cls: type[T]) -> Agent[None, T]:
| 178 | +    """Create Pydantic AI agent using hardcoded preferences."""
| 179 | +    from pydantic_ai import NativeOutput, ToolOutput
| 180 | +    from pydantic_ai.models.openai import OpenAIModel
| 181 | +    from pydantic_ai.providers.azure import AzureProvider
| 182 | +    from .auth import get_shared_token_provider
| 183 | +
| 184 | +    # Prefer straight OpenAI over Azure OpenAI.
| 185 | +    if os.getenv("OPENAI_API_KEY"):
| 186 | +        Wrapper = NativeOutput
| 187 | +        print(f"## Using OpenAI with {Wrapper.__name__} ##")
| 188 | +        model = OpenAIModel("gpt-4o")  # Retrieves OPENAI_API_KEY again.
| 189 | +
| 190 | +    elif azure_openai_api_key := os.getenv("AZURE_OPENAI_API_KEY"):
| 191 | +        # This section is rather specific to our team's setup at Microsoft.
| 192 | +        if azure_openai_api_key == "identity":
| 193 | +            token_provider = get_shared_token_provider()
| 194 | +            azure_openai_api_key = token_provider.get_token()
| 195 | +
| 196 | +        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
| 197 | +        if not azure_endpoint:
| 198 | +            raise RuntimeError("AZURE_OPENAI_ENDPOINT not found")
| 199 | +
| 200 | +        print(f"## {azure_endpoint} ##")
| 201 | +        m = re.search(r"api-version=([\d-]+(?:preview)?)", azure_endpoint)
| 202 | +        if not m:
| 203 | +            raise RuntimeError(
| 204 | +                f"AZURE_OPENAI_ENDPOINT has no valid api-version field: {azure_endpoint}"
| 205 | +            )
| 206 | +        api_version = m.group(1)
| 207 | +        Wrapper = ToolOutput
| 208 | +
| 209 | +        print(f"## Using Azure {api_version} with {Wrapper.__name__} ##")
| 210 | +        model = OpenAIModel(
| 211 | +            "gpt-4o",
| 212 | +            provider=AzureProvider(
| 213 | +                azure_endpoint=azure_endpoint,
| 214 | +                api_version=api_version,
| 215 | +                api_key=azure_openai_api_key,
| 216 | +            ),
| 217 | +        )
| 218 | +
| 219 | +    else:
| 220 | +        raise RuntimeError(
| 221 | +            "Neither OPENAI_API_KEY nor AZURE_OPENAI_API_KEY was provided."
| 222 | +        )
| 223 | +
| 224 | +    return Agent(model, output_type=Wrapper(cls, strict=True), retries=3)
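
For context, a minimal usage sketch of the new helpers (not part of the diff): the CityInfo model and the prompt are invented for illustration, and it assumes the OPENAI_API_KEY or Azure environment variables handled above are set and that setup_logfire and make_agent are imported from the patched module.

# Hypothetical usage sketch; CityInfo and the prompt are made up.
from pydantic import BaseModel

class CityInfo(BaseModel):
    city: str
    country: str

setup_logfire()               # Optional: trace pydantic_ai and httpx calls.
agent = make_agent(CityInfo)  # Agent[None, CityInfo]; picks OpenAI or Azure.
result = agent.run_sync("What is the capital of France?")
print(result.output)          # e.g. CityInfo(city='Paris', country='France')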