diff --git a/examples/mistral/agents/async_multi_turn_conversation.py b/examples/mistral/agents/async_multi_turn_conversation.py new file mode 100644 index 0000000..d24443c --- /dev/null +++ b/examples/mistral/agents/async_multi_turn_conversation.py @@ -0,0 +1,69 @@ +import os +from mistralai import Mistral + +from mistralai.extra.run.context import RunContext +import logging +import time +import asyncio + + +MODEL = "mistral-medium-latest" + +USER_MESSAGE = """ +Please run the Secret Santa for me +To do it properly you need to: +- Get the friend you were assigned to (using the get_secret_santa_assignment function) +- Read their gift wishlist to see what they would like to receive (using the get_gift_wishlist function) +- Find the best website to buy the gift using a web search +- Buy the gift (using the buy_gift function) +- Send it to them (using the send_gift function) +""" + + +async def main(): + api_key = os.environ["MISTRAL_API_KEY"] + mistral_agent_id = os.environ["MISTRAL_AGENT_ID"] + client = Mistral( + api_key=api_key, debug_logger=logging.getLogger("mistralai") + ) + + async with RunContext( + agent_id=mistral_agent_id + ) as run_context: + run_context.register_func(get_secret_santa_assignment) + run_context.register_func(get_gift_wishlist) + run_context.register_func(buy_gift) + run_context.register_func(send_gift) + + await client.beta.conversations.run_async( + run_ctx=run_context, + inputs=USER_MESSAGE, + ) + + +def get_secret_santa_assignment(): + """Get the friend you were assigned to""" + time.sleep(2) + return "John Doe" + + +def get_gift_wishlist(friend_name: str): + """Get the gift wishlist of the friend you were assigned to""" + time.sleep(1.5) + return ["Book", "Chocolate", "T-Shirt"] + + +def buy_gift(gift_name: str): + """Buy the gift you want to send to your friend""" + time.sleep(1.1) + return f"Bought {gift_name}" + + +def send_gift(friend_name: str, gift_name: str, website: str): + """Send the gift to your friend""" + time.sleep(2.2) + return f"Sent {gift_name} to {friend_name} bought on {website}" + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index e3a652f..5167611 100644 --- a/poetry.lock +++ b/poetry.lock @@ -178,10 +178,9 @@ pycparser = "*" name = "charset-normalizer" version = "3.4.0" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
-optional = true +optional = false python-versions = ">=3.7.0" groups = ["main"] -markers = "extra == \"gcp\"" files = [ {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4f9fc98dad6c2eaa32fc3af1417d95b5e3d08aff968df0cd320066def971f9a6"}, {file = "charset_normalizer-3.4.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0de7b687289d3c1b3e8660d0741874abe7888100efe14bd0f9fd7141bcbda92b"}, @@ -502,6 +501,24 @@ pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] reauth = ["pyu2f (>=0.1.5)"] requests = ["requests (>=2.20.0,<3.0.0.dev0)"] +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8"}, + {file = "googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0)"] + [[package]] name = "griffe" version = "1.7.3" @@ -608,6 +625,30 @@ markers = {dev = "python_version >= \"3.10\""} [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "importlib-metadata" +version = "8.7.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd"}, + {file = "importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000"}, +] + +[package.dependencies] +zipp = ">=3.20" + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +perf = ["ipython"] +test = ["flufl.flake8", "importlib_resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +type = ["pytest-mypy"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -766,6 +807,106 @@ files = [ {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] +[[package]] +name = "opentelemetry-api" +version = "1.38.0" +description = "OpenTelemetry Python API" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_api-1.38.0-py3-none-any.whl", hash = "sha256:2891b0197f47124454ab9f0cf58f3be33faca394457ac3e09daba13ff50aa582"}, + {file = "opentelemetry_api-1.38.0.tar.gz", hash = "sha256:f4c193b5e8acb0912b06ac5b16321908dd0843d75049c091487322284a3eea12"}, +] + +[package.dependencies] +importlib-metadata = ">=6.0,<8.8.0" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.38.0" +description = "OpenTelemetry Protobuf encoding" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl", 
hash = "sha256:03cb76ab213300fe4f4c62b7d8f17d97fcfd21b89f0b5ce38ea156327ddda74a"}, + {file = "opentelemetry_exporter_otlp_proto_common-1.38.0.tar.gz", hash = "sha256:e333278afab4695aa8114eeb7bf4e44e65c6607d54968271a249c180b2cb605c"}, +] + +[package.dependencies] +opentelemetry-proto = "1.38.0" + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.38.0" +description = "OpenTelemetry Collector Protobuf over HTTP Exporter" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_exporter_otlp_proto_http-1.38.0-py3-none-any.whl", hash = "sha256:84b937305edfc563f08ec69b9cb2298be8188371217e867c1854d77198d0825b"}, + {file = "opentelemetry_exporter_otlp_proto_http-1.38.0.tar.gz", hash = "sha256:f16bd44baf15cbe07633c5112ffc68229d0edbeac7b37610be0b2def4e21e90b"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.52,<2.0" +opentelemetry-api = ">=1.15,<2.0" +opentelemetry-exporter-otlp-proto-common = "1.38.0" +opentelemetry-proto = "1.38.0" +opentelemetry-sdk = ">=1.38.0,<1.39.0" +requests = ">=2.7,<3.0" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-proto" +version = "1.38.0" +description = "OpenTelemetry Python Proto" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_proto-1.38.0-py3-none-any.whl", hash = "sha256:b6ebe54d3217c42e45462e2a1ae28c3e2bf2ec5a5645236a490f55f45f1a0a18"}, + {file = "opentelemetry_proto-1.38.0.tar.gz", hash = "sha256:88b161e89d9d372ce723da289b7da74c3a8354a8e5359992be813942969ed468"}, +] + +[package.dependencies] +protobuf = ">=5.0,<7.0" + +[[package]] +name = "opentelemetry-sdk" +version = "1.38.0" +description = "OpenTelemetry Python SDK" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_sdk-1.38.0-py3-none-any.whl", hash = "sha256:1c66af6564ecc1553d72d811a01df063ff097cdc82ce188da9951f93b8d10f6b"}, + {file = "opentelemetry_sdk-1.38.0.tar.gz", hash = "sha256:93df5d4d871ed09cb4272305be4d996236eedb232253e3ab864c8620f051cebe"}, +] + +[package.dependencies] +opentelemetry-api = "1.38.0" +opentelemetry-semantic-conventions = "0.59b0" +typing-extensions = ">=4.5.0" + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.59b0" +description = "OpenTelemetry Semantic Conventions" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "opentelemetry_semantic_conventions-0.59b0-py3-none-any.whl", hash = "sha256:35d3b8833ef97d614136e253c1da9342b4c3c083bbaf29ce31d572a1c3825eed"}, + {file = "opentelemetry_semantic_conventions-0.59b0.tar.gz", hash = "sha256:7a6db3f30d70202d5bf9fa4b69bc866ca6a30437287de6c510fb594878aed6b0"}, +] + +[package.dependencies] +opentelemetry-api = "1.38.0" +typing-extensions = ">=4.5.0" + [[package]] name = "packaging" version = "24.2" @@ -811,6 +952,26 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "protobuf" +version = "6.33.0" +description = "" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "protobuf-6.33.0-cp310-abi3-win32.whl", hash = "sha256:d6101ded078042a8f17959eccd9236fb7a9ca20d3b0098bbcb91533a5680d035"}, + {file = "protobuf-6.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:9a031d10f703f03768f2743a1c403af050b6ae1f3480e9c140f39c45f81b13ee"}, + {file = "protobuf-6.33.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:905b07a65f1a4b72412314082c7dbfae91a9e8b68a0cc1577515f8df58ecf455"}, + {file = 
"protobuf-6.33.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e0697ece353e6239b90ee43a9231318302ad8353c70e6e45499fa52396debf90"}, + {file = "protobuf-6.33.0-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:e0a1715e4f27355afd9570f3ea369735afc853a6c3951a6afe1f80d8569ad298"}, + {file = "protobuf-6.33.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:35be49fd3f4fefa4e6e2aacc35e8b837d6703c37a2168a55ac21e9b1bc7559ef"}, + {file = "protobuf-6.33.0-cp39-cp39-win32.whl", hash = "sha256:cd33a8e38ea3e39df66e1bbc462b076d6e5ba3a4ebbde58219d777223a7873d3"}, + {file = "protobuf-6.33.0-cp39-cp39-win_amd64.whl", hash = "sha256:c963e86c3655af3a917962c9619e1a6b9670540351d7af9439d06064e3317cc9"}, + {file = "protobuf-6.33.0-py3-none-any.whl", hash = "sha256:25c9e1963c6734448ea2d308cfa610e692b801304ba0908d7bfa564ac5132995"}, + {file = "protobuf-6.33.0.tar.gz", hash = "sha256:140303d5c8d2037730c548f8c7b93b20bb1dc301be280c378b82b8894589c954"}, +] + [[package]] name = "pyasn1" version = "0.6.1" @@ -1219,10 +1380,9 @@ files = [ name = "requests" version = "2.32.3" description = "Python HTTP for Humans." -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"gcp\"" files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -1472,10 +1632,9 @@ typing-extensions = ">=4.12.0" name = "urllib3" version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." -optional = true +optional = false python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"gcp\"" files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, @@ -1508,6 +1667,26 @@ typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} [package.extras] standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] +[[package]] +name = "zipp" +version = "3.23.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, + {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +type = ["pytest-mypy"] + [extras] agents = ["authlib", "griffe", "mcp"] gcp = ["google-auth", "requests"] @@ -1515,4 +1694,4 @@ gcp = ["google-auth", "requests"] [metadata] lock-version = "2.1" python-versions = ">=3.9" 
-content-hash = "84dda1a6ae0a8491ec9f64e6500480e7ef2e177812a624e388127f354c8e844c" +content-hash = "9d707321f2730f9d1e581d43778dd605a83fdc3d3c375f597b1a2dabb2584ba0" diff --git a/pyproject.toml b/pyproject.toml index 087b570..4bea662 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mistralai" -version = "1.9.11" +version = "1.9.12" description = "Python Client SDK for the Mistral AI API." authors = [{ name = "Mistral" },] readme = "README-PYPI.md" @@ -13,6 +13,10 @@ dependencies = [ "typing-inspection >=0.4.0", "pyyaml (>=6.0.2,<7.0.0)", "invoke (>=2.2.0,<3.0.0)", + "opentelemetry-sdk (>=1.33.1,<2.0.0)", + "opentelemetry-api (>=1.33.1,<2.0.0)", + "opentelemetry-exporter-otlp-proto-http (>=1.37.0,<2.0.0)", + "opentelemetry-semantic-conventions (>=0.59b0,<0.60)", ] [tool.poetry] diff --git a/src/mistralai/_hooks/registration.py b/src/mistralai/_hooks/registration.py index fc3ae79..58bebab 100644 --- a/src/mistralai/_hooks/registration.py +++ b/src/mistralai/_hooks/registration.py @@ -1,5 +1,6 @@ from .custom_user_agent import CustomUserAgentHook from .deprecation_warning import DeprecationWarningHook +from .tracing import TracingHook from .types import Hooks # This file is only ever generated once on the first generation and then is free to be modified. @@ -13,5 +14,9 @@ def init_hooks(hooks: Hooks): with an instance of a hook that implements that specific Hook interface Hooks are registered per SDK instance, and are valid for the lifetime of the SDK instance """ + tracing_hook = TracingHook() hooks.register_before_request_hook(CustomUserAgentHook()) hooks.register_after_success_hook(DeprecationWarningHook()) + hooks.register_after_success_hook(tracing_hook) + hooks.register_before_request_hook(tracing_hook) + hooks.register_after_error_hook(tracing_hook) diff --git a/src/mistralai/_hooks/tracing.py b/src/mistralai/_hooks/tracing.py new file mode 100644 index 0000000..f2ac9c8 --- /dev/null +++ b/src/mistralai/_hooks/tracing.py @@ -0,0 +1,50 @@ +import logging +from typing import Optional, Tuple, Union + +import httpx +from opentelemetry.trace import Span + +from ..extra.observability.otel import ( + get_or_create_otel_tracer, + get_response_and_error, + get_traced_request_and_span, + get_traced_response, +) +from .types import ( + AfterErrorContext, + AfterErrorHook, + AfterSuccessContext, + AfterSuccessHook, + BeforeRequestContext, + BeforeRequestHook, +) + +logger = logging.getLogger(__name__) + + +class TracingHook(BeforeRequestHook, AfterSuccessHook, AfterErrorHook): + def __init__(self) -> None: + self.tracing_enabled, self.tracer = get_or_create_otel_tracer() + self.request_span: Optional[Span] = None + + def before_request( + self, hook_ctx: BeforeRequestContext, request: httpx.Request + ) -> Union[httpx.Request, Exception]: + request, self.request_span = get_traced_request_and_span(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, request=request) + return request + + def after_success( + self, hook_ctx: AfterSuccessContext, response: httpx.Response + ) -> Union[httpx.Response, Exception]: + response = get_traced_response(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, response=response) + return response + + def after_error( + self, + hook_ctx: AfterErrorContext, + response: Optional[httpx.Response], + error: Optional[Exception], + ) -> Union[Tuple[Optional[httpx.Response], Optional[Exception]], Exception]: + if 
response: + response, error = get_response_and_error(tracing_enabled=self.tracing_enabled, tracer=self.tracer, span=self.request_span, operation_id=hook_ctx.operation_id, response=response, error=error) + return response, error diff --git a/src/mistralai/conversations.py b/src/mistralai/conversations.py index 27edded..64551a9 100644 --- a/src/mistralai/conversations.py +++ b/src/mistralai/conversations.py @@ -26,8 +26,10 @@ reconstitue_entries, ) from mistralai.extra.run.utils import run_requirements +from mistralai.extra.observability.otel import GenAISpanEnum, get_or_create_otel_tracer logger = logging.getLogger(__name__) +tracing_enabled, tracer = get_or_create_otel_tracer() if typing.TYPE_CHECKING: from mistralai.extra.run.context import RunContext @@ -67,50 +69,52 @@ async def run_async( from mistralai.extra.run.context import _validate_run from mistralai.extra.run.tools import get_function_calls - req, run_result, input_entries = await _validate_run( - beta_client=Beta(self.sdk_configuration), - run_ctx=run_ctx, - inputs=inputs, - instructions=instructions, - tools=tools, - completion_args=completion_args, - ) + with tracer.start_as_current_span(GenAISpanEnum.VALIDATE_RUN.value): + req, run_result, input_entries = await _validate_run( + beta_client=Beta(self.sdk_configuration), + run_ctx=run_ctx, + inputs=inputs, + instructions=instructions, + tools=tools, + completion_args=completion_args, + ) - while True: - if run_ctx.conversation_id is None: - res = await self.start_async( - inputs=input_entries, - http_headers=http_headers, - name=name, - description=description, - retries=retries, - server_url=server_url, - timeout_ms=timeout_ms, - **req, # type: ignore - ) - run_result.conversation_id = res.conversation_id - run_ctx.conversation_id = res.conversation_id - logger.info( - f"Started Run with conversation with id {res.conversation_id}" - ) - else: - res = await self.append_async( - conversation_id=run_ctx.conversation_id, - inputs=input_entries, - retries=retries, - server_url=server_url, - timeout_ms=timeout_ms, - ) - run_ctx.request_count += 1 - run_result.output_entries.extend(res.outputs) - fcalls = get_function_calls(res.outputs) - if not fcalls: - logger.debug("No more function calls to execute") - break - else: - fresults = await run_ctx.execute_function_calls(fcalls) - run_result.output_entries.extend(fresults) - input_entries = typing.cast(list[InputEntries], fresults) + with tracer.start_as_current_span(GenAISpanEnum.CONVERSATION.value): + while True: + if run_ctx.conversation_id is None: + res = await self.start_async( + inputs=input_entries, + http_headers=http_headers, + name=name, + description=description, + retries=retries, + server_url=server_url, + timeout_ms=timeout_ms, + **req, # type: ignore + ) + run_result.conversation_id = res.conversation_id + run_ctx.conversation_id = res.conversation_id + logger.info( + f"Started Run with conversation with id {res.conversation_id}" + ) + else: + res = await self.append_async( + conversation_id=run_ctx.conversation_id, + inputs=input_entries, + retries=retries, + server_url=server_url, + timeout_ms=timeout_ms, + ) + run_ctx.request_count += 1 + run_result.output_entries.extend(res.outputs) + fcalls = get_function_calls(res.outputs) + if not fcalls: + logger.debug("No more function calls to execute") + break + else: + fresults = await run_ctx.execute_function_calls(fcalls) + run_result.output_entries.extend(fresults) + input_entries = typing.cast(list[InputEntries], fresults) return run_result @run_requirements diff --git 
a/src/mistralai/extra/observability/__init__.py b/src/mistralai/extra/observability/__init__.py new file mode 100644 index 0000000..4ff5873 --- /dev/null +++ b/src/mistralai/extra/observability/__init__.py @@ -0,0 +1,15 @@ +from contextlib import contextmanager + +from opentelemetry import trace as otel_trace + +from .otel import MISTRAL_SDK_OTEL_TRACER_NAME + + +@contextmanager +def trace(name: str, **kwargs): + tracer = otel_trace.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME) + with tracer.start_as_current_span(name, **kwargs) as span: + yield span + + +__all__ = ["trace"] diff --git a/src/mistralai/extra/observability/otel.py b/src/mistralai/extra/observability/otel.py new file mode 100644 index 0000000..46c667d --- /dev/null +++ b/src/mistralai/extra/observability/otel.py @@ -0,0 +1,393 @@ +import copy +import json +import logging +import os +import traceback +from datetime import datetime, timezone +from enum import Enum +from typing import Optional, Tuple + +import httpx +import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes +import opentelemetry.semconv._incubating.attributes.http_attributes as http_attributes +import opentelemetry.semconv.attributes.server_attributes as server_attributes +from opentelemetry import propagate, trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import SpanProcessor, TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult +from opentelemetry.trace import Span, Status, StatusCode, Tracer, set_span_in_context + +logger = logging.getLogger(__name__) + + +OTEL_SERVICE_NAME: str = "mistralai_sdk" +OTEL_EXPORTER_OTLP_ENDPOINT: str = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") +OTEL_EXPORTER_OTLP_TIMEOUT: int = int(os.getenv("OTEL_EXPORTER_OTLP_TIMEOUT", "2")) +OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE", "512")) +OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS", "1000")) +OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE: int = int(os.getenv("OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE", "2048")) +OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS: int = int(os.getenv("OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS", "5000")) + +MISTRAL_SDK_OTEL_TRACER_NAME: str = OTEL_SERVICE_NAME + "_tracer" + +MISTRAL_SDK_DEBUG_TRACING: bool = os.getenv("MISTRAL_SDK_DEBUG_TRACING", "false").lower() == "true" +DEBUG_HINT: str = "To see detailed exporter logs, set MISTRAL_SDK_DEBUG_TRACING=true." + + +class MistralAIAttributes: + MISTRAL_AI_TOTAL_TOKENS = "mistral_ai.request.total_tokens" + MISTRAL_AI_TOOL_CALL_ARGUMENTS = "mistral_ai.tool.call.arguments" + MISTRAL_AI_MESSAGE_ID = "mistral_ai.message.id" + MISTRAL_AI_OPERATION_NAME = "mistral_ai.operation.name" + MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED = "mistral_ai.ocr.usage.pages_processed" + MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES = "mistral_ai.ocr.usage.doc_size_bytes" + MISTRAL_AI_OPERATION_ID = "mistral_ai.operation.id" + MISTRAL_AI_ERROR_TYPE = "mistral_ai.error.type" + MISTRAL_AI_ERROR_MESSAGE = "mistral_ai.error.message" + MISTRAL_AI_ERROR_CODE = "mistral_ai.error.code" + MISTRAL_AI_FUNCTION_CALL_ARGUMENTS = "mistral_ai.function.call.arguments" + +class MistralAINameValues(Enum): + OCR = "ocr" + +class TracingErrors(Exception, Enum): + FAILED_TO_EXPORT_OTEL_SPANS = "Failed to export OpenTelemetry (OTEL) spans."
+ FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING = "Failed to initialize OpenTelemetry tracing." + FAILED_TO_CREATE_SPAN_FOR_REQUEST = "Failed to create span for request." + FAILED_TO_ENRICH_SPAN_WITH_RESPONSE = "Failed to enrich span with response." + FAILED_TO_HANDLE_ERROR_IN_SPAN = "Failed to handle error in span." + FAILED_TO_END_SPAN = "Failed to end span." + + def __str__(self): + return str(self.value) + +class GenAISpanEnum(str, Enum): + CONVERSATION = "conversation" + CONV_REQUEST = "POST /v1/conversations" + EXECUTE_TOOL = "execute_tool" + VALIDATE_RUN = "validate_run" + + @staticmethod + def function_call(func_name: str): + return f"function_call[{func_name}]" + + +def parse_time_to_nanos(ts: str) -> int: + dt = datetime.fromisoformat(ts.replace("Z", "+00:00")).astimezone(timezone.utc) + return int(dt.timestamp() * 1e9) + +def set_available_attributes(span: Span, attributes: dict) -> None: + for attribute, value in attributes.items(): + if value: + span.set_attribute(attribute, value) + + +def enrich_span_from_request(span: Span, request: httpx.Request) -> Span: + if not request.url.port: + # From httpx doc: + # Note that the URL class performs port normalization as per the WHATWG spec. + # Default ports for "http", "https", "ws", "wss", and "ftp" schemes are always treated as None. + # Handling default ports since most of the time we are using https + if request.url.scheme == "https": + port = 443 + elif request.url.scheme == "http": + port = 80 + else: + port = -1 + else: + port = request.url.port + + span.set_attributes({ + http_attributes.HTTP_REQUEST_METHOD: request.method, + http_attributes.HTTP_URL: str(request.url), + server_attributes.SERVER_ADDRESS: request.headers.get("host", ""), + server_attributes.SERVER_PORT: port + }) + if request._content: + request_body = json.loads(request._content) + + attributes = { + gen_ai_attributes.GEN_AI_REQUEST_CHOICE_COUNT: request_body.get("n", None), + gen_ai_attributes.GEN_AI_REQUEST_ENCODING_FORMATS: request_body.get("encoding_formats", None), + gen_ai_attributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: request_body.get("frequency_penalty", None), + gen_ai_attributes.GEN_AI_REQUEST_MAX_TOKENS: request_body.get("max_tokens", None), + gen_ai_attributes.GEN_AI_REQUEST_MODEL: request_body.get("model", None), + gen_ai_attributes.GEN_AI_REQUEST_PRESENCE_PENALTY: request_body.get("presence_penalty", None), + gen_ai_attributes.GEN_AI_REQUEST_SEED: request_body.get("random_seed", None), + gen_ai_attributes.GEN_AI_REQUEST_STOP_SEQUENCES: request_body.get("stop", None), + gen_ai_attributes.GEN_AI_REQUEST_TEMPERATURE: request_body.get("temperature", None), + gen_ai_attributes.GEN_AI_REQUEST_TOP_P: request_body.get("top_p", None), + gen_ai_attributes.GEN_AI_REQUEST_TOP_K: request_body.get("top_k", None), + # Input messages are likely to be large, containing user/PII data and other sensitive information. + # Also structured attributes are not yet supported on spans in Python. + # For those reasons, we will not record the input messages for now. + gen_ai_attributes.GEN_AI_INPUT_MESSAGES: None, + } + # Set attributes only if they are not None. + # From OpenTelemetry documentation: None is not a valid attribute value per spec / is not a permitted value type for an attribute. 
+ set_available_attributes(span, attributes) + return span + + +def enrich_span_from_response(tracer: trace.Tracer, span: Span, operation_id: str, response: httpx.Response) -> None: + span.set_status(Status(StatusCode.OK)) + response_data = json.loads(response.content) + + # Base attributes + attributes: dict[str, str | int] = { + http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code, + MistralAIAttributes.MISTRAL_AI_OPERATION_ID: operation_id, + gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value + } + + # Add usage attributes if available + usage = response_data.get("usage", {}) + if usage: + attributes.update({ + gen_ai_attributes.GEN_AI_USAGE_PROMPT_TOKENS: usage.get("prompt_tokens", 0), + gen_ai_attributes.GEN_AI_USAGE_OUTPUT_TOKENS: usage.get("completion_tokens", 0), + MistralAIAttributes.MISTRAL_AI_TOTAL_TOKENS: usage.get("total_tokens", 0) + }) + + span.set_attributes(attributes) + if operation_id == "agents_api_v1_agents_create": + # Semantics from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-agent-spans/#create-agent-span + agent_attributes = { + gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CREATE_AGENT.value, + gen_ai_attributes.GEN_AI_AGENT_DESCRIPTION: response_data.get("description", ""), + gen_ai_attributes.GEN_AI_AGENT_ID: response_data.get("id", ""), + gen_ai_attributes.GEN_AI_AGENT_NAME: response_data.get("name", ""), + gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", ""), + gen_ai_attributes.GEN_AI_SYSTEM_INSTRUCTIONS: response_data.get("instructions", "") + } + span.set_attributes(agent_attributes) + if operation_id in ["agents_api_v1_conversations_start", "agents_api_v1_conversations_append"]: + outputs = response_data.get("outputs", []) + conversation_attributes = { + gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.INVOKE_AGENT.value, + gen_ai_attributes.GEN_AI_CONVERSATION_ID: response_data.get("conversation_id", "") + } + span.set_attributes(conversation_attributes) + parent_context = set_span_in_context(span) + + for output in outputs: + # TODO: Only enrich the spans if it's a single turn conversation. 
+ # Multi turn conversations are handled in the extra.run.tools.create_function_result function + if output["type"] == "function.call": + pass + if output["type"] == "tool.execution": + start_ns = parse_time_to_nanos(output["created_at"]) + end_ns = parse_time_to_nanos(output["completed_at"]) + child_span = tracer.start_span("Tool Execution", start_time=start_ns, context=parent_context) + tool_attributes = { + gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value, + gen_ai_attributes.GEN_AI_TOOL_CALL_ID: output.get("id", ""), + MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: output.get("arguments", ""), + gen_ai_attributes.GEN_AI_TOOL_NAME: output.get("name", "") + } + child_span.set_attributes(tool_attributes) + child_span.end(end_time=end_ns) + if output["type"] == "message.output": + start_ns = parse_time_to_nanos(output["created_at"]) + end_ns = parse_time_to_nanos(output["completed_at"]) + child_span = tracer.start_span("Message Output", start_time=start_ns, context=parent_context) + message_attributes = { + gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.CHAT.value, + gen_ai_attributes.GEN_AI_PROVIDER_NAME: gen_ai_attributes.GenAiProviderNameValues.MISTRAL_AI.value, + MistralAIAttributes.MISTRAL_AI_MESSAGE_ID: output.get("id", ""), + gen_ai_attributes.GEN_AI_AGENT_ID: output.get("agent_id", ""), + gen_ai_attributes.GEN_AI_REQUEST_MODEL: output.get("model", "") + } + child_span.set_attributes(message_attributes) + child_span.end(end_time=end_ns) + if operation_id == "ocr_v1_ocr_post": + usage_info = response_data.get("usage_info", "") + ocr_attributes = { + MistralAIAttributes.MISTRAL_AI_OPERATION_NAME: MistralAINameValues.OCR.value, + MistralAIAttributes.MISTRAL_AI_OCR_USAGE_PAGES_PROCESSED: usage_info.get("pages_processed", "") if usage_info else "", + MistralAIAttributes.MISTRAL_AI_OCR_USAGE_DOC_SIZE_BYTES: usage_info.get("doc_size_bytes", "") if usage_info else "", + gen_ai_attributes.GEN_AI_REQUEST_MODEL: response_data.get("model", "") + } + span.set_attributes(ocr_attributes) + + +class GenAISpanProcessor(SpanProcessor): + def on_start(self, span, parent_context = None): + span.set_attributes({"agent.trace.public": ""}) + + +class QuietOTLPSpanExporter(OTLPSpanExporter): + def export(self, spans): + try: + return super().export(spans) + except Exception: + logger.warning(f"{TracingErrors.FAILED_TO_EXPORT_OTEL_SPANS} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}") + return SpanExportResult.FAILURE + + +def get_or_create_otel_tracer() -> Tuple[bool, Tracer]: + """ + 3 possible cases: + + -> [SDK in a Workflow / App] If there is already a tracer provider set -> use that one + + -> [SDK standalone] If no tracer provider is set but the OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT + + -> Else tracing is disabled + """ + tracing_enabled = True + tracer_provider = trace.get_tracer_provider() + + if isinstance(tracer_provider, trace.ProxyTracerProvider): + if OTEL_EXPORTER_OTLP_ENDPOINT: + # SDK standalone: No tracer provider but OTEL_EXPORTER_OTLP_ENDPOINT is set -> create a new tracer provider that exports to the OTEL_EXPORTER_OTLP_ENDPOINT + try: + exporter = QuietOTLPSpanExporter( + endpoint=OTEL_EXPORTER_OTLP_ENDPOINT, + timeout=OTEL_EXPORTER_OTLP_TIMEOUT + ) + resource = Resource.create(attributes={SERVICE_NAME: OTEL_SERVICE_NAME}) + tracer_provider = TracerProvider(resource=resource) + + 
span_processor = BatchSpanProcessor( + exporter, + export_timeout_millis=OTEL_EXPORTER_OTLP_EXPORT_TIMEOUT_MILLIS, + max_export_batch_size=OTEL_EXPORTER_OTLP_MAX_EXPORT_BATCH_SIZE, + schedule_delay_millis=OTEL_EXPORTER_OTLP_SCHEDULE_DELAY_MILLIS, + max_queue_size=OTEL_EXPORTER_OTLP_MAX_QUEUE_SIZE + ) + + tracer_provider.add_span_processor(span_processor) + tracer_provider.add_span_processor(GenAISpanProcessor()) + trace.set_tracer_provider(tracer_provider) + + except Exception: + logger.warning(f"{TracingErrors.FAILED_TO_INITIALIZE_OPENTELEMETRY_TRACING} {(traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT)}") + tracing_enabled = False + else: + # No tracer provider nor OTEL_EXPORTER_OTLP_ENDPOINT set -> tracing is disabled + tracing_enabled = False + + tracer = tracer_provider.get_tracer(MISTRAL_SDK_OTEL_TRACER_NAME) + + return tracing_enabled, tracer + +def get_traced_request_and_span(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, request: httpx.Request) -> Tuple[httpx.Request, Optional[Span]]: + if not tracing_enabled: + return request, span + + try: + span = tracer.start_span(name=operation_id) + # Inject the span context into the request headers to be used by the backend service to continue the trace + propagate.inject(request.headers) + span = enrich_span_from_request(span, request) + except Exception: + logger.warning( + "%s %s", + TracingErrors.FAILED_TO_CREATE_SPAN_FOR_REQUEST, + traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT, + ) + if span: + end_span(span=span) + span = None + + return request, span + + +def get_traced_response(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response) -> httpx.Response: + if not tracing_enabled or not span: + return response + try: + is_stream_response = not response.is_closed and not response.is_stream_consumed + if is_stream_response: + return TracedResponse.from_response(resp=response, span=span) + enrich_span_from_response( + tracer, span, operation_id, response + ) + except Exception: + logger.warning( + "%s %s", + TracingErrors.FAILED_TO_ENRICH_SPAN_WITH_RESPONSE, + traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT, + ) + if span: + end_span(span=span) + return response + +def get_response_and_error(tracing_enabled: bool, tracer: Tracer, span: Optional[Span], operation_id: str, response: httpx.Response, error: Optional[Exception]) -> Tuple[httpx.Response, Optional[Exception]]: + if not tracing_enabled or not span: + return response, error + try: + if error: + span.record_exception(error) + span.set_status(Status(StatusCode.ERROR, str(error))) + if hasattr(response, "_content") and response._content: + response_body = json.loads(response._content) + if response_body.get("object", "") == "error": + if error_msg := response_body.get("message", ""): + attributes = { + http_attributes.HTTP_RESPONSE_STATUS_CODE: response.status_code, + MistralAIAttributes.MISTRAL_AI_ERROR_TYPE: response_body.get("type", ""), + MistralAIAttributes.MISTRAL_AI_ERROR_MESSAGE: error_msg, + MistralAIAttributes.MISTRAL_AI_ERROR_CODE: response_body.get("code", ""), + } + for attribute, value in attributes.items(): + if value: + span.set_attribute(attribute, value) + span.end() + span = None + except Exception: + logger.warning( + "%s %s", + TracingErrors.FAILED_TO_HANDLE_ERROR_IN_SPAN, + traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT, + ) + + if span: + span.end() + span = None + return response, error + + +def 
end_span(span: Span) -> None: + try: + span.end() + except Exception: + logger.warning( + "%s %s", + TracingErrors.FAILED_TO_END_SPAN, + traceback.format_exc() if MISTRAL_SDK_DEBUG_TRACING else DEBUG_HINT, + ) + +class TracedResponse(httpx.Response): + """ + TracedResponse is a subclass of httpx.Response that ends the span when the response is closed. + + This hack allows ending the span only once the stream is fully consumed. + """ + def __init__(self, *args, span: Optional[Span], **kwargs) -> None: + super().__init__(*args, **kwargs) + self.span = span + + def close(self) -> None: + if self.span: + end_span(span=self.span) + super().close() + + async def aclose(self) -> None: + if self.span: + end_span(span=self.span) + await super().aclose() + + @classmethod + def from_response(cls, resp: httpx.Response, span: Optional[Span]) -> "TracedResponse": + traced_resp = cls.__new__(cls) + traced_resp.__dict__ = copy.copy(resp.__dict__) + traced_resp.span = span + + # Warning: this syntax bypasses the __init__ method. + # If you add init logic in the TracedResponse.__init__ method, you will need to add the following line for it to execute: + # traced_resp.__init__(your_arguments) + + return traced_resp diff --git a/src/mistralai/extra/run/tools.py b/src/mistralai/extra/run/tools.py index 81fec66..e3f8093 100644 --- a/src/mistralai/extra/run/tools.py +++ b/src/mistralai/extra/run/tools.py @@ -8,6 +8,7 @@ import json from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union +from opentelemetry import trace from griffe import ( Docstring, DocstringSectionKind, @@ -15,9 +16,11 @@ DocstringParameter, DocstringSection, ) +import opentelemetry.semconv._incubating.attributes.gen_ai_attributes as gen_ai_attributes from mistralai.extra.exceptions import RunException from mistralai.extra.mcp.base import MCPClientProtocol +from mistralai.extra.observability.otel import GenAISpanEnum, MistralAIAttributes, set_available_attributes from mistralai.extra.run.result import RunOutputEntries from mistralai.models import ( FunctionResultEntry, @@ -191,22 +194,31 @@ async def create_function_result( if isinstance(function_call.arguments, str) else function_call.arguments ) - try: - if isinstance(run_tool, RunFunction): - res = run_tool.callable(**arguments) - elif isinstance(run_tool, RunCoroutine): - res = await run_tool.awaitable(**arguments) - elif isinstance(run_tool, RunMCPTool): - res = await run_tool.mcp_client.execute_tool(function_call.name, arguments) - except Exception as e: - if continue_on_fn_error is True: - return FunctionResultEntry( - tool_call_id=function_call.tool_call_id, - result=f"Error while executing {function_call.name}: {str(e)}", - ) - raise RunException( - f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'" - ) from e + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span(GenAISpanEnum.function_call(function_call.name)) as span: + try: + if isinstance(run_tool, RunFunction): + res = run_tool.callable(**arguments) + elif isinstance(run_tool, RunCoroutine): + res = await run_tool.awaitable(**arguments) + elif isinstance(run_tool, RunMCPTool): + res = await run_tool.mcp_client.execute_tool(function_call.name, arguments) + function_call_attributes = { + gen_ai_attributes.GEN_AI_OPERATION_NAME: gen_ai_attributes.GenAiOperationNameValues.EXECUTE_TOOL.value, + gen_ai_attributes.GEN_AI_TOOL_CALL_ID: function_call.id, + MistralAIAttributes.MISTRAL_AI_TOOL_CALL_ARGUMENTS: str(function_call.arguments), + 
gen_ai_attributes.GEN_AI_TOOL_NAME: function_call.name + } + set_available_attributes(span, function_call_attributes) + except Exception as e: + if continue_on_fn_error is True: + return FunctionResultEntry( + tool_call_id=function_call.tool_call_id, + result=f"Error while executing {function_call.name}: {str(e)}", + ) + raise RunException( + f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'" + ) from e return FunctionResultEntry( tool_call_id=function_call.tool_call_id,
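A minimal usage sketch of the tracing this patch adds, under stated assumptions: per get_or_create_otel_tracer, when no global tracer provider is already installed, tracing only turns on if OTEL_EXPORTER_OTLP_ENDPOINT is set before mistralai is imported (otel.py reads the variable at import time, so it should be exported in the shell rather than set in-process). The endpoint value, agent id, and span name below are placeholders; trace is the helper exported from mistralai.extra.observability in this diff, and the rest uses the public SDK surface shown above.

# Run with e.g.:
#   export OTEL_EXPORTER_OTLP_ENDPOINT="http://localhost:4318/v1/traces"  # placeholder collector
#   export MISTRAL_API_KEY=... MISTRAL_AGENT_ID=...
import asyncio
import os

from mistralai import Mistral
from mistralai.extra.observability import trace
from mistralai.extra.run.context import RunContext


async def main():
    client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

    # Wrap application code in a custom root span; "secret-santa-run" is an
    # arbitrary name. Spans created by the TracingHook and by
    # conversations.run_async during the call should nest under it.
    with trace("secret-santa-run"):
        async with RunContext(agent_id=os.environ["MISTRAL_AGENT_ID"]) as run_context:
            await client.beta.conversations.run_async(
                run_ctx=run_context,
                inputs="Say hello.",
            )


if __name__ == "__main__":
    asyncio.run(main())

Since trace() delegates to start_as_current_span on the SDK's named tracer, it degrades gracefully: with no provider and no endpoint configured, OpenTelemetry hands back a no-op tracer, so the wrapper is safe to leave in place when tracing is disabled.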