diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd index ebadeec97984..2f173d7394b1 100644 --- a/sdks/python/apache_beam/metrics/cells.pxd +++ b/sdks/python/apache_beam/metrics/cells.pxd @@ -22,6 +22,7 @@ from cpython.datetime cimport datetime cdef class MetricCell(object): cdef object _lock + cdef object _container_lock cpdef bint update(self, value) except -1 cdef datetime _start_time diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py index b4703c5b5b96..6f477e5bead8 100644 --- a/sdks/python/apache_beam/metrics/cells.py +++ b/sdks/python/apache_beam/metrics/cells.py @@ -61,8 +61,9 @@ class MetricCell(object): and may be subject to parallel/concurrent updates. Cells should only be used directly within a runner. """ - def __init__(self): - self._lock = threading.Lock() + def __init__(self, container_lock=None): + self._lock = threading.Lock() # Lock for this specific cell's internal data + self._container_lock = container_lock # Lock from the MetricsContainer self._start_time = None def update(self, value): @@ -106,8 +107,8 @@ class CounterCell(MetricCell): This class is thread safe. """ - def __init__(self, *args): - super().__init__(*args) + def __init__(self, container_lock=None): + super().__init__(container_lock=container_lock) self.value = 0 def reset(self): @@ -137,7 +138,11 @@ def update(self, value): # directly by circumventing the GIL. self.value += ivalue else: - with self._lock: + # If a container lock is provided, use it. Otherwise, use cell's own lock. + # This ensures that if the cell is managed by a MetricsContainer, + # the container's lock is used for thread safety across cells. + lock_to_use = self._container_lock if self._container_lock else self._lock + with lock_to_use: self.value += value def get_cumulative(self): @@ -171,8 +176,8 @@ class DistributionCell(MetricCell): This class is thread safe. """ - def __init__(self, *args): - super().__init__(*args) + def __init__(self, container_lock=None): + super().__init__(container_lock=container_lock) self.data = DistributionData.identity_element() def reset(self): @@ -190,7 +195,9 @@ def update(self, value): # We will hold the GIL throughout the entire _update. self._update(value) else: - with self._lock: + # If a container lock is provided, use it. Otherwise, use cell's own lock. + lock_to_use = self._container_lock if self._container_lock else self._lock + with lock_to_use: self._update(value) def _update(self, value): @@ -226,8 +233,8 @@ class AbstractMetricCell(MetricCell): This class is thread safe. """ - def __init__(self, data_class): - super().__init__() + def __init__(self, data_class, container_lock=None): + super().__init__(container_lock=container_lock) self.data_class = data_class self.data = self.data_class.identity_element() @@ -240,11 +247,13 @@ def combine(self, other: 'AbstractMetricCell') -> 'AbstractMetricCell': return result def set(self, value): - with self._lock: + lock_to_use = self._container_lock if self._container_lock else self._lock + with lock_to_use: self._update_locked(value) def update(self, value): - with self._lock: + lock_to_use = self._container_lock if self._container_lock else self._lock + with lock_to_use: self._update_locked(value) def _update_locked(self, value): @@ -269,8 +278,8 @@ class GaugeCell(AbstractMetricCell): This class is thread safe. """ - def __init__(self): - super().__init__(GaugeData) + def __init__(self, container_lock=None): + super().__init__(GaugeData, container_lock=container_lock) def _update_locked(self, value): # Set the value directly without checking timestamp, because @@ -298,8 +307,8 @@ class StringSetCell(AbstractMetricCell): This class is thread safe. """ - def __init__(self): - super().__init__(StringSetData) + def __init__(self, container_lock=None): + super().__init__(StringSetData, container_lock=container_lock) def add(self, value): self.update(value) @@ -327,8 +336,8 @@ class BoundedTrieCell(AbstractMetricCell): This class is thread safe. """ - def __init__(self): - super().__init__(BoundedTrieData) + def __init__(self, container_lock=None): + super().__init__(BoundedTrieData, container_lock=container_lock) def add(self, value): self.update(value) diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py index a3414447c48f..7590a50c493c 100644 --- a/sdks/python/apache_beam/metrics/execution.py +++ b/sdks/python/apache_beam/metrics/execution.py @@ -47,6 +47,7 @@ from apache_beam.metrics.cells import CounterCell from apache_beam.metrics.cells import DistributionCell from apache_beam.metrics.cells import GaugeCell +from apache_beam.metrics.cells import MetricCellFactory from apache_beam.metrics.cells import StringSetCell from apache_beam.metrics.cells import StringSetData from apache_beam.runners.worker import statesampler @@ -57,7 +58,6 @@ from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import DistributionData from apache_beam.metrics.cells import MetricCell - from apache_beam.metrics.cells import MetricCellFactory from apache_beam.metrics.metricbase import MetricName from apache_beam.portability.api import metrics_pb2 @@ -272,10 +272,24 @@ def get_bounded_trie(self, metric_name): def get_metric_cell(self, typed_metric_name): # type: (_TypedMetricName) -> MetricCell + # First check without a lock. cell = self.metrics.get(typed_metric_name, None) if cell is None: + # If not found, acquire lock and check again. + # This is to prevent duplicate cell creation in concurrent scenarios. with self.lock: - cell = self.metrics[typed_metric_name] = typed_metric_name.cell_type() + cell = self.metrics.get(typed_metric_name, None) + if cell is None: + if isinstance(typed_metric_name.cell_type, MetricCellFactory): + # If it's a factory, call it without container_lock, + # as the factory's __call__ should handle cell creation. + cell = self.metrics[ + typed_metric_name] = typed_metric_name.cell_type() + else: + # Otherwise, assume it's a MetricCell class and pass container_lock. + cell = self.metrics[ + typed_metric_name] = typed_metric_name.cell_type( + container_lock=self.lock) return cell def get_cumulative(self): @@ -325,14 +339,14 @@ def to_runner_api_monitoring_infos(self, transform_id): """Returns a list of MonitoringInfos for the metrics in this container.""" with self.lock: items = list(self.metrics.items()) - all_metrics = [ - cell.to_runner_api_monitoring_info(key.metric_name, transform_id) - for key, cell in items - ] - return { - monitoring_infos.to_key(mi): mi - for mi in all_metrics if mi is not None - } + all_metrics = [ + cell.to_runner_api_monitoring_info(key.metric_name, transform_id) + for key, cell in items + ] + return { + monitoring_infos.to_key(mi): mi + for mi in all_metrics if mi is not None + } def reset(self): # type: () -> None