Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions cognite/extractorutils/unstable/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class MyConfig(ExtractorConfig):
another_parameter: int
schedule: ScheduleConfig

class MyMetrics(BaseMetrics):
def __init__(self, extractor_name: str, extractor_version: str):
super().__init__(extractor_name, extractor_version)
self.custom_counter = Counter("custom_counter", "A custom counter")

class MyExtractor(Extractor[MyConfig]):
NAME = "My Extractor"
EXTERNAL_ID = "my-extractor"
Expand All @@ -30,6 +35,9 @@ class MyExtractor(Extractor[MyConfig]):

CONFIG_TYPE = MyConfig

# Override metrics type annotation for IDE support
metrics: MyMetrics

def __init_tasks__(self) -> None:
self.add_task(
ScheduledTask(
Expand All @@ -42,6 +50,8 @@ def __init_tasks__(self) -> None:

def my_task_function(self, task_context: TaskContext) -> None:
task_context.logger.info("Running my task")
# IDE will now autocomplete custom_counter
self.metrics.custom_counter.inc()
"""

import logging
Expand All @@ -59,7 +69,7 @@ def my_task_function(self, task_context: TaskContext) -> None:
from typing_extensions import Self, assert_never

from cognite.extractorutils._inner_util import _resolve_log_level
from cognite.extractorutils.metrics import BaseMetrics
from cognite.extractorutils.metrics import BaseMetrics, safe_get
from cognite.extractorutils.statestore import (
AbstractStateStore,
LocalStateStore,
Expand Down Expand Up @@ -117,11 +127,13 @@ def __init__(
application_config: _T,
current_config_revision: ConfigRevision,
log_level_override: str | None = None,
metrics_class: type[BaseMetrics] | None = None,
) -> None:
self.connection_config = connection_config
self.application_config = application_config
self.current_config_revision: ConfigRevision = current_config_revision
self.log_level_override = log_level_override
self.metrics_class: type[BaseMetrics] | None = metrics_class


class Extractor(Generic[ConfigType], CogniteLogger):
Expand Down Expand Up @@ -149,9 +161,7 @@ class Extractor(Generic[ConfigType], CogniteLogger):

cancellation_token: CancellationToken

def __init__(
self, config: FullConfig[ConfigType], checkin_worker: CheckinWorker, metrics: BaseMetrics | None = None
) -> None:
def __init__(self, config: FullConfig[ConfigType], checkin_worker: CheckinWorker) -> None:
self._logger = logging.getLogger(f"{self.EXTERNAL_ID}.main")
self._checkin_worker = checkin_worker

Expand All @@ -175,7 +185,8 @@ def __init__(

self._tasks: list[Task] = []
self._start_time: datetime
self._metrics: BaseMetrics | None = metrics

self.metrics: BaseMetrics = self._load_metrics(config.metrics_class)

self.metrics_push_manager = (
self.metrics_config.create_manager(self.cognite_client, cancellation_token=self.cancellation_token)
Expand Down Expand Up @@ -262,6 +273,18 @@ def _setup_logging(self) -> None:
"Defaulted to console logging."
)

def _load_metrics(self, metrics_class: type[BaseMetrics] | None = None) -> BaseMetrics:
    """
    Resolve the metrics instance for this extractor.

    Reuses existing singleton if available to avoid Prometheus registry conflicts.
    """
    if metrics_class is None or not issubclass(metrics_class, BaseMetrics):
        # No usable custom class: fall back to the default metrics, tagged
        # with this extractor's identity.
        return safe_get(BaseMetrics, extractor_name=self.EXTERNAL_ID, extractor_version=self.VERSION)
    # NOTE(review): the custom class is fetched without name/version kwargs —
    # assumes its constructor needs no required arguments; confirm against safe_get.
    return safe_get(metrics_class)

def _load_state_store(self) -> None:
"""
Searches through the config object for a StateStoreConfig.
Expand Down Expand Up @@ -379,10 +402,8 @@ def restart(self) -> None:
self.cancellation_token.cancel()

@classmethod
def _init_from_runtime(
cls, config: FullConfig[ConfigType], checkin_worker: CheckinWorker, metrics: BaseMetrics
) -> Self:
return cls(config, checkin_worker, metrics)
def _init_from_runtime(cls, config: FullConfig[ConfigType], checkin_worker: CheckinWorker) -> Self:
return cls(config, checkin_worker)

def add_task(self, task: Task) -> None:
"""
Expand Down
12 changes: 5 additions & 7 deletions cognite/extractorutils/unstable/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,13 @@ def _extractor_process_entrypoint(
controls: _RuntimeControls,
config: FullConfig,
checkin_worker: CheckinWorker,
metrics: BaseMetrics | None = None,
) -> None:
logger = logging.getLogger(f"{extractor_class.EXTERNAL_ID}.runtime")
checkin_worker.active_revision = config.current_config_revision
checkin_worker.set_on_fatal_error_handler(lambda _: on_fatal_error(controls))
checkin_worker.set_on_revision_change_handler(lambda _: on_revision_changed(controls))
checkin_worker.set_retry_startup(extractor_class.RETRY_STARTUP)
if not metrics:
metrics = BaseMetrics(extractor_name=extractor_class.NAME, extractor_version=extractor_class.VERSION)
extractor = extractor_class._init_from_runtime(config, checkin_worker, metrics)
extractor = extractor_class._init_from_runtime(config, checkin_worker)
extractor._attach_runtime_controls(
cancel_event=controls.cancel_event,
message_queue=controls.message_queue,
Expand Down Expand Up @@ -138,13 +135,13 @@ class Runtime(Generic[ExtractorType]):
def __init__(
self,
extractor: type[ExtractorType],
metrics: BaseMetrics | None = None,
metrics: type[BaseMetrics] | None = None,
) -> None:
self._extractor_class = extractor
self._cancellation_token = CancellationToken()
self._cancellation_token.cancel_on_interrupt()
self._message_queue: Queue[RuntimeMessage] = Queue()
self._metrics = metrics
self._metrics_class = metrics
self.logger = logging.getLogger(f"{self._extractor_class.EXTERNAL_ID}.runtime")
self._setup_logging()
self._cancel_event: MpEvent | None = None
Expand Down Expand Up @@ -273,7 +270,7 @@ def _spawn_extractor(

process = Process(
target=_extractor_process_entrypoint,
args=(self._extractor_class, controls, config, checkin_worker, self._metrics),
args=(self._extractor_class, controls, config, checkin_worker),
)

process.start()
Expand Down Expand Up @@ -507,6 +504,7 @@ def _main_runtime(self, args: Namespace) -> None:
application_config=application_config,
current_config_revision=current_config_revision,
log_level_override=args.log_level,
metrics_class=self._metrics_class,
),
checkin_worker,
)
Expand Down
30 changes: 30 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,47 @@
from enum import Enum

import pytest
from prometheus_client.core import REGISTRY

from cognite.client import CogniteClient
from cognite.client.config import ClientConfig
from cognite.client.credentials import OAuthClientCredentials
from cognite.client.data_classes.data_modeling import NodeId
from cognite.client.exceptions import CogniteAPIError, CogniteNotFoundError
from cognite.extractorutils import metrics

NUM_NODES = 5000
NUM_EDGES = NUM_NODES // 100


def _purge_metrics_registrations() -> None:
    """Clear the metrics singleton cache and unregister every Prometheus collector."""
    metrics._metrics_singularities.clear()
    # Unregister all collectors to prevent "Duplicated timeseries" errors when a
    # later test re-creates metrics whose names are already in the global registry.
    # Snapshot the keys first: unregistering mutates the registry's mapping.
    for collector in list(REGISTRY._collector_to_names.keys()):
        with contextlib.suppress(Exception):
            REGISTRY.unregister(collector)


@pytest.fixture(autouse=True)
def reset_singleton() -> Generator[None, None, None]:
    """
    Reset global metrics state around every test.

    Ensures the ``_metrics_singularities`` class variable is cleared and all
    Prometheus collectors are unregistered both before and after each test,
    providing test isolation.
    """
    # Clean up before the test so leftovers from a crashed test don't leak in.
    _purge_metrics_registrations()

    yield

    # Clean up after the test as well, leaving a pristine registry behind.
    _purge_metrics_registrations()


class ETestType(Enum):
TIME_SERIES = "time_series"
CDM_TIME_SERIES = "cdm_time_series"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unstable/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import gzip
import json
import os
from collections import Counter
from collections.abc import Callable, Generator, Iterator
from threading import RLock
from time import sleep, time
Expand All @@ -10,6 +9,7 @@

import pytest
import requests_mock
from prometheus_client.core import Counter

from cognite.client import CogniteClient
from cognite.client.config import ClientConfig
Expand Down
5 changes: 3 additions & 2 deletions tests/test_unstable/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,11 @@ def counting_push(self: CognitePusher) -> None:
application_config=app_config,
current_config_revision=1,
log_level_override=override_level,
metrics_class=TestMetrics,
)
worker = get_checkin_worker(connection_config)
extractor = TestExtractor(full_config, worker, metrics=TestMetrics)
assert isinstance(extractor._metrics, TestMetrics) or extractor._metrics == TestMetrics
extractor = TestExtractor(full_config, worker)
assert isinstance(extractor.metrics, TestMetrics)

with contextlib.ExitStack() as stack:
stack.enter_context(contextlib.suppress(Exception))
Expand Down
112 changes: 105 additions & 7 deletions tests/test_unstable/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,35 @@
from cognite.extractorutils.unstable.core.base import ConfigRevision, FullConfig
from cognite.extractorutils.unstable.core.checkin_worker import CheckinWorker
from cognite.extractorutils.unstable.core.runtime import Runtime
from cognite.extractorutils.unstable.core.tasks import StartupTask, TaskContext
from test_unstable.conftest import TestConfig, TestExtractor, TestMetrics


class MetricsTestExtractor(SimpleExtractor):
    """Custom extractor for testing metrics in multiprocessing context."""

    def __init_tasks__(self) -> None:
        super().__init_tasks__()

        def bump_counter_and_report(context: TaskContext) -> None:
            # Two increments so the expected value (2) is unambiguous —
            # it cannot be confused with a default or a single stray call.
            counter = self.metrics.a_counter
            counter.inc()
            counter.inc()

            # Read the raw value (private Prometheus API) and log it so the
            # parent process can assert on the child's captured output.
            context.info(f"METRICS_TEST: Counter value is {counter._value.get()}")

        # Registered as a startup task: runs once when the extractor boots.
        self.add_task(
            StartupTask(
                name="test-metrics",
                description="Test metrics increment",
                target=bump_counter_and_report,
            )
        )


@pytest.fixture
def local_config_file() -> Generator[Path, None, None]:
file = Path(__file__).parent.parent.parent / f"test-{randint(0, 1000000)}.yaml"
Expand Down Expand Up @@ -396,11 +422,83 @@ def test_logging_on_windows_with_import_error(
assert mock_root_logger.addHandler.call_count == 1


def test_extractor_with_metrics() -> None:
runtime = Runtime(TestExtractor, metrics=TestMetrics)
assert isinstance(runtime._metrics, TestMetrics) or runtime._metrics == TestMetrics
def test_extractor_with_metrics(
connection_config: ConnectionConfig, tmp_path: Path, monkeypatch: MonkeyPatch, capfd: pytest.CaptureFixture[str]
) -> None:
"""
Test metrics_class is properly passed through Runtime to child process.
This test verifies multiprocessing integration with metrics and counter increments.
"""
cfg_dir = Path("cognite/examples/unstable/extractors/simple_extractor/config")
base_conn = cfg_dir / "connection_config.yaml"
base_app = cfg_dir / "config.yaml"

conn_file = tmp_path / f"test-{randint(0, 1000000)}-connection_config.yaml"
_write_conn_from_fixture(base_conn, conn_file, connection_config)

app_file = tmp_path / f"test-{randint(0, 1000000)}-config.yaml"
app_file.write_text(base_app.read_text(encoding="utf-8"))

argv = [
"simple-extractor",
"--cwd",
str(tmp_path),
"-c",
conn_file.name,
"-f",
app_file.name,
"--skip-init-checks",
"-l",
"info",
]

# The metrics instance should be a singleton
another_runtime = Runtime(TestExtractor, metrics=TestMetrics)
assert another_runtime._metrics is runtime._metrics
assert isinstance(another_runtime._metrics, TestMetrics) or another_runtime._metrics == TestMetrics
monkeypatch.setattr(sys, "argv", argv)

runtime = Runtime(MetricsTestExtractor, metrics=TestMetrics)

# Verify runtime stores metrics class
assert runtime._metrics_class is TestMetrics, "Runtime should store TestMetrics class"

child_holder = {}
original_spawn = Runtime._spawn_extractor

def spy_spawn(self: Self, config: FullConfig, checkin_worker: CheckinWorker) -> Process:
assert config.metrics_class is TestMetrics, "FullConfig should carry TestMetrics class"

p = original_spawn(
self,
config,
checkin_worker,
)
child_holder["proc"] = p
return p

monkeypatch.setattr(Runtime, "_spawn_extractor", spy_spawn, raising=True)

t = Thread(target=runtime.run, name="RuntimeMain")
t.start()

start = time.time()
while "proc" not in child_holder and time.time() - start < 10:
time.sleep(0.05)

assert "proc" in child_holder, "Extractor process was not spawned in time."
proc = child_holder["proc"]

time.sleep(1.5) # Give more time for the startup task to run

runtime._cancellation_token.cancel()

t.join(timeout=30)
assert not t.is_alive(), "Runtime did not shut down within timeout after cancellation."

proc.join(timeout=0)
assert not proc.is_alive(), "Extractor process is still alive"

out, err = capfd.readouterr()
combined = (out or "") + (err or "")

# Verify metrics counter was incremented
assert "METRICS_TEST: Counter value is 2" in combined, (
f"Expected metrics counter to be 2 in child process.\nCaptured output:\n{combined}"
)
Loading
Loading