Skip to content

Commit 112e279

Browse files
attempt at fixing tests (done with LLM)
1 parent fd33875 commit 112e279

File tree

4 files changed

+176
-111
lines changed

4 files changed

+176
-111
lines changed

synapse/handlers/stats.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def __init__(self, hs: "HomeServer"):
7272
# Guard to ensure we only process deltas one at a time
7373
self._is_processing = False
7474

75+
# Initialize room count metrics to 0
76+
known_rooms_gauge.set(0, {SERVER_NAME_LABEL: self.server_name})
77+
locally_joined_rooms_gauge.set(0, {SERVER_NAME_LABEL: self.server_name})
78+
7579
if self.stats_enabled and hs.config.worker.run_background_tasks:
7680
self.notifier.add_replication_callback(self.notify_new_event)
7781

synapse/util/batching_queue.py

Lines changed: 69 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#
2121

2222
import logging
23+
import weakref
2324
from typing import (
2425
TYPE_CHECKING,
2526
Awaitable,
@@ -46,13 +47,66 @@
4647
V = TypeVar("V")
4748
R = TypeVar("R")
4849

49-
# number_queued = meter.create_observable_gauge(
50-
# "synapse_util_batching_queue_number_queued",
51-
# description="The number of items waiting in the queue across all keys",
52-
# )
50+
# Global registry to track all BatchingQueue instances.
51+
# We use a WeakSet so that queues can be garbage collected when no longer referenced.
52+
_batching_queue_registry: "weakref.WeakSet[BatchingQueue]" = weakref.WeakSet()
53+
54+
55+
def _collect_number_queued(options: object) -> list[Observation]:
56+
"""Callback to collect number_queued metrics from all BatchingQueue instances."""
57+
observations = []
58+
for queue in _batching_queue_registry:
59+
observations.append(
60+
Observation(
61+
sum(len(q) for q in queue._next_values.values()),
62+
{"name": queue._name, SERVER_NAME_LABEL: queue.server_name},
63+
)
64+
)
65+
return observations
66+
67+
68+
def _collect_number_of_keys(options: object) -> list[Observation]:
69+
"""Callback to collect number_of_keys metrics from all BatchingQueue instances."""
70+
observations = []
71+
for queue in _batching_queue_registry:
72+
observations.append(
73+
Observation(
74+
len(queue._next_values),
75+
{"name": queue._name, SERVER_NAME_LABEL: queue.server_name},
76+
)
77+
)
78+
return observations
79+
80+
81+
def _collect_number_in_flight(options: object) -> list[Observation]:
82+
"""Callback to collect number_in_flight metrics from all BatchingQueue instances."""
83+
observations = []
84+
for queue in _batching_queue_registry:
85+
observations.append(
86+
Observation(
87+
queue._number_in_flight,
88+
{"name": queue._name, SERVER_NAME_LABEL: queue.server_name},
89+
)
90+
)
91+
return observations
92+
93+
94+
# Global observable gauges that collect from all BatchingQueue instances
95+
number_queued = meter.create_observable_gauge(
96+
"synapse_util_batching_queue_number_queued",
97+
callbacks=[_collect_number_queued],
98+
description="The number of items waiting in the queue across all keys",
99+
)
100+
101+
number_of_keys = meter.create_observable_gauge(
102+
"synapse_util_batching_queue_number_of_keys",
103+
callbacks=[_collect_number_of_keys],
104+
description="The number of distinct keys that have items queued",
105+
)
53106

54107
number_in_flight = meter.create_observable_gauge(
55108
"synapse_util_batching_queue_number_pending",
109+
callbacks=[_collect_number_in_flight],
56110
description="The number of items across all keys either being processed or waiting in a queue",
57111
)
58112

@@ -107,45 +161,19 @@ def __init__(
107161
# The function to call with batches of values.
108162
self._process_batch_callback = process_batch_callback
109163

110-
self.number_queued = meter.create_observable_gauge(
111-
"synapse_util_batching_queue_number_queued",
112-
callbacks=[
113-
lambda options: [
114-
Observation(
115-
sum(len(q) for q in self._next_values.values()),
116-
{"name": self._name, SERVER_NAME_LABEL: self.server_name},
117-
)
118-
]
119-
],
120-
description="The number of items waiting in the queue across all keys",
121-
)
164+
# Counter for number of items in flight (being processed or waiting).
165+
self._number_in_flight: int = 0
122166

123-
self.number_of_keys = meter.create_observable_gauge(
124-
"synapse_util_batching_queue_number_of_keys",
125-
description="The number of distinct keys that have items queued",
126-
callbacks=[
127-
lambda options: [
128-
Observation(
129-
len(self._next_values),
130-
{"name": self._name, SERVER_NAME_LABEL: self.server_name},
131-
)
132-
]
133-
],
134-
)
135-
136-
self._number_in_flight_metric = meter.create_up_down_counter(
137-
"synapse_util_batching_queue_number_pending",
138-
description="The number of items across all keys either being processed or waiting in a queue",
139-
)
167+
# Register this instance with the global registry so metrics can be collected.
168+
_batching_queue_registry.add(self)
140169

141170
def shutdown(self) -> None:
142171
"""
143172
Prepares the object for garbage collection by removing any handed out
144173
references.
145174
"""
146-
# there doesn't seem to be an otel equivalent for those
147-
# number_queued.remove(self._name, self.server_name)
148-
# number_of_keys.remove(self._name, self.server_name)
175+
# The global registry uses WeakSet, so instances are automatically
176+
# removed when garbage collected. No explicit cleanup needed.
149177

150178
async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
151179
"""Adds the value to the queue with the given key, returning the result
@@ -167,13 +195,11 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
167195
if key not in self._processing_keys:
168196
self.hs.run_as_background_process(self._name, self._process_queue, key)
169197

170-
self._number_in_flight_metric.add(
171-
1, {"name": self._name, SERVER_NAME_LABEL: self.server_name}
172-
)
173-
res = await make_deferred_yieldable(d)
174-
self._number_in_flight_metric.add(
175-
-1, {"name": self._name, SERVER_NAME_LABEL: self.server_name}
176-
)
198+
self._number_in_flight += 1
199+
try:
200+
res = await make_deferred_yieldable(d)
201+
finally:
202+
self._number_in_flight -= 1
177203
return res
178204

179205
async def _process_queue(self, key: Hashable) -> None:

tests/handlers/test_stats.py

Lines changed: 89 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,50 @@ def _set_metrics_to_zero(self) -> None:
5757
This method resets the metrics to zero before each test to ensure
5858
that each test starts with a clean slate.
5959
"""
60-
from opentelemetry import metrics
61-
from opentelemetry.sdk.metrics import MeterProvider
62-
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
63-
64-
# Create a fresh reader and provider
65-
self.reader = InMemoryMetricReader()
66-
provider = MeterProvider(metric_readers=[self.reader])
67-
# Set the global provider
68-
# Any new metric instruments created after this will use the clean state.
69-
metrics.set_meter_provider(provider)
60+
from synapse.metrics import (
61+
SERVER_NAME_LABEL,
62+
known_rooms_gauge,
63+
locally_joined_rooms_gauge,
64+
)
65+
66+
# Reset the gauge values to 0 for this server
67+
known_rooms_gauge.set(0, {SERVER_NAME_LABEL: self.hs.hostname})
68+
locally_joined_rooms_gauge.set(0, {SERVER_NAME_LABEL: self.hs.hostname})
69+
70+
def _get_gauge_values(
71+
self, metrics: list[tuple[str, dict[str, str]]]
72+
) -> list[Optional[float]]:
73+
"""
74+
Get multiple gauge values from the Prometheus registry in a single call.
75+
76+
The standard REGISTRY.get_sample_value() doesn't work for OpenTelemetry
77+
metrics because the OTel exporter doesn't register its metric names.
78+
Additionally, the OTel collector only returns data on the first collect()
79+
call, so we must collect all data once and then look up all values.
80+
81+
Args:
82+
metrics: List of (metric_name, labels) tuples to look up.
83+
84+
Returns:
85+
List of values in the same order as the input metrics.
86+
"""
87+
# Collect all data from all collectors into a lookup dict
88+
all_samples: dict[tuple[str, tuple[tuple[str, str], ...]], float] = {}
89+
for collector in REGISTRY._collector_to_names.keys():
90+
try:
91+
for metric_family in collector.collect():
92+
for sample in metric_family.samples:
93+
key = (metric_family.name, tuple(sorted(sample.labels.items())))
94+
all_samples[key] = sample.value
95+
except Exception:
96+
continue
97+
98+
# Look up each requested metric
99+
results: list[Optional[float]] = []
100+
for metric_name, labels in metrics:
101+
key = (metric_name, tuple(sorted(labels.items())))
102+
results.append(all_samples.get(key))
103+
return results
70104

71105
def _add_background_updates(self) -> None:
72106
"""
@@ -184,19 +218,18 @@ def test_create_room(self) -> None:
184218
When we create a room, it should have statistics already ready.
185219
"""
186220
self._perform_background_initial_update()
187-
self.assertEqual(
188-
REGISTRY.get_sample_value(
189-
"synapse_known_rooms_total", labels={"server_name": self.hs.hostname}
190-
),
191-
0.0,
192-
)
193-
self.assertEqual(
194-
REGISTRY.get_sample_value(
195-
"synapse_locally_joined_rooms_total",
196-
labels={"server_name": self.hs.hostname},
197-
),
198-
0.0,
199-
)
221+
known_rooms, locally_joined = self._get_gauge_values(
222+
[
223+
("synapse_known_rooms_total", {"server_name": self.hs.hostname}),
224+
(
225+
"synapse_locally_joined_rooms_total",
226+
{"server_name": self.hs.hostname},
227+
),
228+
]
229+
)
230+
self.assertEqual(known_rooms, 0.0)
231+
self.assertEqual(locally_joined, 0.0)
232+
200233
u1 = self.register_user("u1", "pass")
201234
u1token = self.login("u1", "pass")
202235
r1 = self.helper.create_room_as(u1, tok=u1token)
@@ -223,19 +256,17 @@ def test_create_room(self) -> None:
223256
self.assertEqual(r2stats["banned_members"], 0)
224257

225258
# There are 2 rooms created. Check the room metrics were updated.
226-
self.assertEqual(
227-
REGISTRY.get_sample_value(
228-
"synapse_known_rooms_total", labels={"server_name": self.hs.hostname}
229-
),
230-
2,
231-
)
232-
self.assertEqual(
233-
REGISTRY.get_sample_value(
234-
"synapse_locally_joined_rooms_total",
235-
labels={"server_name": self.hs.hostname},
236-
),
237-
2,
238-
)
259+
known_rooms, locally_joined = self._get_gauge_values(
260+
[
261+
("synapse_known_rooms_total", {"server_name": self.hs.hostname}),
262+
(
263+
"synapse_locally_joined_rooms_total",
264+
{"server_name": self.hs.hostname},
265+
),
266+
]
267+
)
268+
self.assertEqual(known_rooms, 2)
269+
self.assertEqual(locally_joined, 2)
239270

240271
def test_updating_profile_information_does_not_increase_joined_members_count(
241272
self,
@@ -647,19 +678,17 @@ def test_room_metrics(self) -> None:
647678
"""
648679

649680
self._perform_background_initial_update()
650-
self.assertEqual(
651-
REGISTRY.get_sample_value(
652-
"synapse_known_rooms_total", labels={"server_name": self.hs.hostname}
653-
),
654-
0.0,
655-
)
656-
self.assertEqual(
657-
REGISTRY.get_sample_value(
658-
"synapse_locally_joined_rooms_total",
659-
labels={"server_name": self.hs.hostname},
660-
),
661-
0.0,
662-
)
681+
known_rooms, locally_joined = self._get_gauge_values(
682+
[
683+
("synapse_known_rooms_total", {"server_name": self.hs.hostname}),
684+
(
685+
"synapse_locally_joined_rooms_total",
686+
{"server_name": self.hs.hostname},
687+
),
688+
]
689+
)
690+
self.assertEqual(known_rooms, 0.0)
691+
self.assertEqual(locally_joined, 0.0)
663692

664693
u1 = self.register_user("u1", "pass")
665694
u1token = self.login("u1", "pass")
@@ -670,19 +699,17 @@ def test_room_metrics(self) -> None:
670699
self.helper.leave(r2, u1, tok=u1token)
671700

672701
# Check the locally joined rooms metric after creating rooms
673-
self.assertEqual(
674-
REGISTRY.get_sample_value(
675-
"synapse_locally_joined_rooms_total",
676-
labels={"server_name": self.hs.hostname},
677-
),
678-
1,
679-
)
680-
self.assertEqual(
681-
REGISTRY.get_sample_value(
682-
"synapse_known_rooms_total", labels={"server_name": self.hs.hostname}
683-
),
684-
2,
685-
)
702+
known_rooms, locally_joined = self._get_gauge_values(
703+
[
704+
("synapse_known_rooms_total", {"server_name": self.hs.hostname}),
705+
(
706+
"synapse_locally_joined_rooms_total",
707+
{"server_name": self.hs.hostname},
708+
),
709+
]
710+
)
711+
self.assertEqual(locally_joined, 1)
712+
self.assertEqual(known_rooms, 2)
686713

687714
# Check the stats for both rooms
688715
r1stats = self._get_current_stats("room", r1)

tests/util/test_batching_queue.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
from synapse.util.batching_queue import (
2929
BatchingQueue,
3030
number_in_flight,
31+
number_of_keys,
32+
number_queued,
3133
)
3234

3335
from tests.unittest import HomeserverTestCase
@@ -61,27 +63,33 @@ async def _process_queue(self, values: list[str]) -> str:
6163

6264
def _get_sample_with_name(self, metric: ObservableGauge, name: str) -> float:
6365
"""For a prometheus metric get the value of the sample that has a
64-
matching "name" label.
66+
matching "name" label and matching metric name.
6567
"""
66-
print(vars(metric))
68+
# The metric.name attribute gives us the OTel instrument name
69+
metric_name = metric.name
70+
6771
for metric_family in REGISTRY.collect():
72+
# Check if this metric family corresponds to our metric
73+
# (the family name must equal the metric name exactly)
74+
if metric_family.name != metric_name:
75+
continue
6876
for sample in metric_family.samples:
69-
if sample.labels.get("name") == name: # and sample.name == metric.name:
77+
if sample.labels.get("name") == name:
7078
return sample.value
7179

72-
self.fail("Found no matching sample")
80+
self.fail(f"Found no matching sample for metric={metric_name}, name={name}")
7381

7482
def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
7583
"""Assert that the metrics are correct"""
7684

77-
sample = self._get_sample_with_name(self.queue.number_queued, self.queue._name)
85+
sample = self._get_sample_with_name(number_queued, self.queue._name)
7886
self.assertEqual(
7987
sample,
8088
queued,
8189
"number_queued",
8290
)
8391

84-
sample = self._get_sample_with_name(self.queue.number_of_keys, self.queue._name)
92+
sample = self._get_sample_with_name(number_of_keys, self.queue._name)
8593
self.assertEqual(sample, keys, "number_of_keys")
8694

8795
sample = self._get_sample_with_name(number_in_flight, self.queue._name)

0 commit comments

Comments
 (0)