Skip to content

Commit 4a2ea27

Browse files
authored
fix(mypy): resolve OpenTelemetry typing issues in telemetry.py (#3943)
Fixes mypy type errors in OpenTelemetry integration: - Add type aliases for AttributeValue and Attributes - Add helper to filter None values from attributes (OpenTelemetry doesn't accept None) - Cast metric and tracer objects to proper types - Update imports after refactoring No functional changes.
1 parent 85887d7 commit 4a2ea27

File tree

2 files changed

+36
-22
lines changed

2 files changed

+36
-22
lines changed

src/llama_stack/core/telemetry/telemetry.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66

77
import os
88
import threading
9+
from collections.abc import Mapping, Sequence
910
from datetime import datetime
1011
from enum import Enum
1112
from typing import (
1213
Annotated,
1314
Any,
1415
Literal,
16+
cast,
1517
)
1618

1719
from opentelemetry import metrics, trace
@@ -30,6 +32,10 @@
3032

3133
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
3234

35+
# Type alias for OpenTelemetry attribute values (excludes None)
36+
AttributeValue = str | bool | int | float | Sequence[str] | Sequence[bool] | Sequence[int] | Sequence[float]
37+
Attributes = Mapping[str, AttributeValue]
38+
3339

3440
@json_schema_type
3541
class SpanStatus(Enum):
@@ -428,6 +434,13 @@ class QueryMetricsResponse(BaseModel):
428434
logger = get_logger(name=__name__, category="telemetry")
429435

430436

437+
def _clean_attributes(attrs: dict[str, Any] | None) -> Attributes | None:
438+
"""Remove None values from attributes dict to match OpenTelemetry's expected type."""
439+
if attrs is None:
440+
return None
441+
return {k: v for k, v in attrs.items() if v is not None}
442+
443+
431444
def is_tracing_enabled(tracer):
432445
with tracer.start_as_current_span("check_tracing") as span:
433446
return span.is_recording()
@@ -456,7 +469,7 @@ def __init__(self) -> None:
456469
# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter
457470
span_exporter = OTLPSpanExporter()
458471
span_processor = BatchSpanProcessor(span_exporter)
459-
trace.get_tracer_provider().add_span_processor(span_processor)
472+
cast(TracerProvider, trace.get_tracer_provider()).add_span_processor(span_processor)
460473

461474
metric_reader = PeriodicExportingMetricReader(OTLPMetricExporter())
462475
metric_provider = MeterProvider(metric_readers=[metric_reader])
@@ -474,7 +487,7 @@ async def initialize(self) -> None:
474487

475488
async def shutdown(self) -> None:
476489
if self.is_otel_endpoint_set:
477-
trace.get_tracer_provider().force_flush()
490+
cast(TracerProvider, trace.get_tracer_provider()).force_flush()
478491

479492
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
480493
if isinstance(event, UnstructuredLogEvent):
@@ -515,7 +528,7 @@ def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
515528
unit=unit,
516529
description=f"Counter for {name}",
517530
)
518-
return _GLOBAL_STORAGE["counters"][name]
531+
return cast(metrics.Counter, _GLOBAL_STORAGE["counters"][name])
519532

520533
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
521534
assert self.meter is not None
@@ -525,7 +538,7 @@ def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
525538
unit=unit,
526539
description=f"Gauge for {name}",
527540
)
528-
return _GLOBAL_STORAGE["gauges"][name]
541+
return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])
529542

530543
def _log_metric(self, event: MetricEvent) -> None:
531544
# Add metric as an event to the current span
@@ -560,10 +573,10 @@ def _log_metric(self, event: MetricEvent) -> None:
560573
return
561574
if isinstance(event.value, int):
562575
counter = self._get_or_create_counter(event.metric, event.unit)
563-
counter.add(event.value, attributes=event.attributes)
576+
counter.add(event.value, attributes=_clean_attributes(event.attributes))
564577
elif isinstance(event.value, float):
565578
up_down_counter = self._get_or_create_up_down_counter(event.metric, event.unit)
566-
up_down_counter.add(event.value, attributes=event.attributes)
579+
up_down_counter.add(event.value, attributes=_clean_attributes(event.attributes))
567580

568581
def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDownCounter:
569582
assert self.meter is not None
@@ -573,7 +586,7 @@ def _get_or_create_up_down_counter(self, name: str, unit: str) -> metrics.UpDown
573586
unit=unit,
574587
description=f"UpDownCounter for {name}",
575588
)
576-
return _GLOBAL_STORAGE["up_down_counters"][name]
589+
return cast(metrics.UpDownCounter, _GLOBAL_STORAGE["up_down_counters"][name])
577590

578591
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
579592
with self._lock:
@@ -601,7 +614,8 @@ def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
601614
if event.payload.parent_span_id:
602615
parent_span_id = int(event.payload.parent_span_id, 16)
603616
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
604-
context = trace.set_span_in_context(parent_span)
617+
if parent_span:
618+
context = trace.set_span_in_context(parent_span)
605619
elif traceparent:
606620
carrier = {
607621
"traceparent": traceparent,
@@ -612,15 +626,17 @@ def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
612626
span = tracer.start_span(
613627
name=event.payload.name,
614628
context=context,
615-
attributes=event.attributes or {},
629+
attributes=_clean_attributes(event.attributes),
616630
)
617631
_GLOBAL_STORAGE["active_spans"][span_id] = span
618632

619633
elif isinstance(event.payload, SpanEndPayload):
620-
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
634+
span = _GLOBAL_STORAGE["active_spans"].get(span_id) # type: ignore[assignment]
621635
if span:
622636
if event.attributes:
623-
span.set_attributes(event.attributes)
637+
cleaned_attrs = _clean_attributes(event.attributes)
638+
if cleaned_attrs:
639+
span.set_attributes(cleaned_attrs)
624640

625641
status = (
626642
trace.Status(status_code=trace.StatusCode.OK)

src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66
from collections.abc import Mapping, Sequence
7-
from typing import Any, Literal
7+
from typing import Any, Literal, cast
88

99
from sqlalchemy import (
1010
JSON,
@@ -55,17 +55,17 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
5555
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
5656
op, operand = next(iter(value.items()))
5757
if op == "==" or op == "=":
58-
return column == operand
58+
return cast(ColumnElement[Any], column == operand)
5959
if op == ">":
60-
return column > operand
60+
return cast(ColumnElement[Any], column > operand)
6161
if op == "<":
62-
return column < operand
62+
return cast(ColumnElement[Any], column < operand)
6363
if op == ">=":
64-
return column >= operand
64+
return cast(ColumnElement[Any], column >= operand)
6565
if op == "<=":
66-
return column <= operand
66+
return cast(ColumnElement[Any], column <= operand)
6767
raise ValueError(f"Unsupported operator '{op}' in where mapping")
68-
return column == value
68+
return cast(ColumnElement[Any], column == value)
6969

7070

7171
class SqlAlchemySqlStoreImpl(SqlStore):
@@ -210,10 +210,8 @@ async def fetch_all(
210210
query = query.limit(fetch_limit)
211211

212212
result = await session.execute(query)
213-
if result.rowcount == 0:
214-
rows = []
215-
else:
216-
rows = [dict(row._mapping) for row in result]
213+
# Iterate directly - if no rows, list comprehension yields empty list
214+
rows = [dict(row._mapping) for row in result]
217215

218216
# Always return pagination result
219217
has_more = False

0 commit comments

Comments
 (0)