Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 59 additions & 30 deletions cognite/extractorutils/unstable/configuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os
import re
from collections.abc import Iterator
from datetime import timedelta
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -53,23 +54,44 @@ class ConfigModel(BaseModel):
)


class _ClientCredentialsConfig(ConfigModel):
type: Literal["client-credentials"]
class Scopes(str):
    """
    A space-separated string of token scopes that can be iterated scope by scope.

    The instance is still a ``str`` holding the raw configured value. The individual scopes are obtained by
    splitting that value on single spaces, and iterating a ``Scopes`` yields those scopes rather than the
    characters of the string.
    """

    def __init__(self, scopes: str) -> None:
        # ``str.__new__`` has already stored the raw value; only the parsed list is kept here.
        self._scopes = scopes.split(" ")

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: "GetCoreSchemaHandler") -> "CoreSchema":
        """
        Validate the input as a plain ``str`` first, then wrap the validated value in ``Scopes``.
        """
        return core_schema.no_info_after_validator_function(cls, handler(str))

    def __eq__(self, other: object) -> bool:
        """
        Compare the parsed scope lists; defer to the other operand when it is not a ``Scopes``.
        """
        if not isinstance(other, Scopes):
            return NotImplemented
        return self._scopes == other._scopes

    def __hash__(self) -> int:
        """
        Hash the raw string value.

        The parsed scopes are stored in a list, which is unhashable, so hashing ``self._scopes`` directly
        raises ``TypeError``. Two ``Scopes`` have equal scope lists exactly when their raw strings are equal
        (``split(" ")`` is reversible), so the string hash is consistent with ``__eq__`` and also with
        comparison against a plain ``str``.
        """
        return str.__hash__(self)

    def __iter__(self) -> Iterator[str]:
        """
        Iterate over the individual scopes, not over the characters of the string.
        """
        return iter(self._scopes)


class BaseCredentialsConfig(ConfigModel):
    """
    Fields shared by the credential configurations in this module.

    Both ``_ClientCredentialsConfig`` and ``_ClientCertificateConfig`` derive from this class.
    """

    # Client identifier handed to the credential provider built in ``get_cognite_client``.
    client_id: str
    # Parsed by ``Scopes``, which splits the configured string on single spaces.
    scopes: Scopes


class _ClientCredentialsConfig(BaseCredentialsConfig):
type: Literal["client-credentials"]
client_secret: str
token_url: str
scopes: list[str]
resource: str | None = None
audience: str | None = None


class _ClientCertificateConfig(ConfigModel):
class _ClientCertificateConfig(BaseCredentialsConfig):
type: Literal["client-certificate"]
client_id: str
path: Path
password: str | None = None
authority_url: str
scopes: list[str]


AuthenticationConfig = Annotated[_ClientCredentialsConfig | _ClientCertificateConfig, Field(discriminator="type")]
Expand Down Expand Up @@ -191,18 +213,26 @@ def __repr__(self) -> str:
return self._expression


class _ConnectionParameters(ConfigModel):
gzip_compression: bool = False
status_forcelist: list[int] = Field(default_factory=lambda: [429, 502, 503, 504])
max_retries: int = 10
max_retries_connect: int = 3
max_retry_backoff: TimeIntervalConfig = Field(default_factory=lambda: TimeIntervalConfig("30s"))
max_connection_pool_size: int = 50
ssl_verify: bool = True
proxies: dict[str, str] = Field(default_factory=dict)
class RetriesConfig(ConfigModel):
    """
    Retry and timeout settings for requests made through the Cognite client.

    ``ConnectionConfig.get_cognite_client`` copies ``max_retries`` and ``max_backoff`` (in seconds) onto the
    SDK's ``global_config`` and passes ``timeout`` (in seconds) when constructing the client.
    """

    # Validated to be >= -1. NOTE(review): -1 presumably means "retry without limit" -- confirm against the SDK.
    max_retries: int = Field(default=10, ge=-1)
    # Written to ``global_config.max_retry_backoff``; defaults to "30s".
    max_backoff: TimeIntervalConfig = Field(default_factory=lambda: TimeIntervalConfig("30s"))
    # Request timeout; defaults to "30s".
    timeout: TimeIntervalConfig = Field(default_factory=lambda: TimeIntervalConfig("30s"))


class SslCertificatesConfig(ConfigModel):
    """
    SSL certificate verification settings.
    """

    # When False, ``get_cognite_client`` sets ``global_config.disable_ssl`` to True.
    verify: bool = True
    # NOTE(review): not read anywhere in the visible code -- presumably certificates to trust explicitly; confirm.
    allow_list: list[str] | None = None


class ConnectionParameters(ConfigModel):
    """
    Low-level connection settings, grouped into retry behaviour and SSL certificate handling.

    Each section falls back to a default-constructed model when it is omitted.
    """

    # Retry count, backoff and timeout settings.
    retries: RetriesConfig = Field(default_factory=RetriesConfig)
    # SSL verification settings.
    ssl_certificates: SslCertificatesConfig = Field(default_factory=SslCertificatesConfig)


class IntegrationConfig(ConfigModel):
    """
    Identifies the integration this extractor runs as.
    """

    # Sent as ``externalId`` in the check-in and extractor-info requests.
    external_id: str


class ConnectionConfig(ConfigModel):
"""
Configuration for connecting to a Cognite Data Fusion project.
Expand All @@ -216,11 +246,11 @@ class ConnectionConfig(ConfigModel):
project: str
base_url: str

integration: str
integration: IntegrationConfig

authentication: AuthenticationConfig

connection: _ConnectionParameters = Field(default_factory=_ConnectionParameters)
connection: ConnectionParameters = Field(default_factory=ConnectionParameters)

def get_cognite_client(self, client_name: str) -> CogniteClient:
"""
Expand All @@ -235,14 +265,9 @@ def get_cognite_client(self, client_name: str) -> CogniteClient:
from cognite.client.config import global_config

global_config.disable_pypi_version_check = True
global_config.disable_gzip = not self.connection.gzip_compression
global_config.status_forcelist = set(self.connection.status_forcelist)
global_config.max_retries = self.connection.max_retries
global_config.max_retries_connect = self.connection.max_retries_connect
global_config.max_retry_backoff = self.connection.max_retry_backoff.seconds
global_config.max_connection_pool_size = self.connection.max_connection_pool_size
global_config.disable_ssl = not self.connection.ssl_verify
global_config.proxies = self.connection.proxies
global_config.max_retries = self.connection.retries.max_retries
global_config.max_retry_backoff = self.connection.retries.max_backoff.seconds
global_config.disable_ssl = not self.connection.ssl_certificates.verify

credential_provider: CredentialProvider
match self.authentication:
Expand Down Expand Up @@ -270,7 +295,7 @@ def get_cognite_client(self, client_name: str) -> CogniteClient:
client_id=client_certificate.client_id,
cert_thumbprint=str(thumbprint),
certificate=str(key),
scopes=client_certificate.scopes,
scopes=list(client_certificate.scopes),
)

case _:
Expand All @@ -280,7 +305,7 @@ def get_cognite_client(self, client_name: str) -> CogniteClient:
project=self.project,
base_url=self.base_url,
client_name=client_name,
timeout=self.connection.timeout.seconds,
timeout=self.connection.retries.timeout.seconds,
credentials=credential_provider,
)

Expand Down Expand Up @@ -315,7 +340,9 @@ def from_environment(cls) -> "ConnectionConfig":
client_id=os.environ["COGNITE_CLIENT_ID"],
client_secret=os.environ["COGNITE_CLIENT_SECRET"],
token_url=os.environ["COGNITE_TOKEN_URL"],
scopes=os.environ["COGNITE_TOKEN_SCOPES"].split(","),
scopes=Scopes(
os.environ["COGNITE_TOKEN_SCOPES"],
),
)
elif "COGNITE_CLIENT_CERTIFICATE_PATH" in os.environ:
auth = _ClientCertificateConfig(
Expand All @@ -324,15 +351,17 @@ def from_environment(cls) -> "ConnectionConfig":
path=Path(os.environ["COGNITE_CLIENT_CERTIFICATE_PATH"]),
password=os.environ.get("COGNITE_CLIENT_CERTIFICATE_PATH"),
authority_url=os.environ["COGNITE_AUTHORITY_URL"],
scopes=os.environ["COGNITE_TOKEN_SCOPES"].split(","),
scopes=Scopes(
os.environ["COGNITE_TOKEN_SCOPES"],
),
)
else:
raise KeyError("Missing auth, either COGNITE_CLIENT_SECRET or COGNITE_CLIENT_CERTIFICATE_PATH must be set")

return ConnectionConfig(
project=os.environ["COGNITE_PROJECT"],
base_url=os.environ["COGNITE_BASE_URL"],
integration=os.environ["COGNITE_INTEGRATION"],
integration=IntegrationConfig(external_id=os.environ["COGNITE_INTEGRATION"]),
authentication=auth,
)

Expand Down
4 changes: 2 additions & 2 deletions cognite/extractorutils/unstable/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _checkin(self) -> None:
res = self.cognite_client.post(
f"/api/v1/projects/{self.cognite_client.config.project}/integrations/checkin",
json={
"externalId": self.connection_config.integration,
"externalId": self.connection_config.integration.external_id,
"taskEvents": task_updates,
"errors": error_updates,
},
Expand Down Expand Up @@ -345,7 +345,7 @@ def _report_extractor_info(self) -> None:
self.cognite_client.post(
f"/api/v1/projects/{self.cognite_client.config.project}/integrations/extractorinfo",
json={
"externalId": self.connection_config.integration,
"externalId": self.connection_config.integration.external_id,
"activeConfigRevision": self.current_config_revision,
"extractor": {
"version": self.VERSION,
Expand Down
12 changes: 6 additions & 6 deletions cognite/extractorutils/unstable/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def main() -> None:
load_file,
load_from_cdf,
)
from cognite.extractorutils.unstable.configuration.models import ConnectionConfig
from cognite.extractorutils.unstable.configuration.models import ConnectionConfig, ExtractorConfig
from cognite.extractorutils.unstable.core._dto import Error
from cognite.extractorutils.unstable.core.errors import ErrorLevel
from cognite.extractorutils.util import now

from ._messaging import RuntimeMessage
from .base import ConfigRevision, ConfigType, Extractor, FullConfig
from .base import ConfigRevision, Extractor, FullConfig

__all__ = ["ExtractorType", "Runtime"]

Expand Down Expand Up @@ -173,7 +173,7 @@ def _try_get_application_config(
self,
args: Namespace,
connection_config: ConnectionConfig,
) -> tuple[ConfigType, ConfigRevision]:
) -> tuple[ExtractorConfig, ConfigRevision]:
current_config_revision: ConfigRevision

if args.local_override:
Expand All @@ -194,7 +194,7 @@ def _try_get_application_config(

application_config, current_config_revision = load_from_cdf(
self._cognite_client,
connection_config.integration,
connection_config.integration.external_id,
self._extractor_class.CONFIG_TYPE,
)

Expand All @@ -204,7 +204,7 @@ def _safe_get_application_config(
self,
args: Namespace,
connection_config: ConnectionConfig,
) -> tuple[ConfigType, ConfigRevision] | None:
) -> tuple[ExtractorConfig, ConfigRevision] | None:
prev_error: str | None = None

while not self._cancellation_token.is_cancelled:
Expand Down Expand Up @@ -233,7 +233,7 @@ def _safe_get_application_config(
self._cognite_client.post(
f"/api/v1/projects/{self._cognite_client.config.project}/odin/checkin",
json={
"externalId": connection_config.integration,
"externalId": connection_config.integration.external_id,
"errors": [error.model_dump()],
},
headers={"cdf-version": "alpha"},
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
]

dependencies = [
"cognite-sdk>=7.59.0",
"cognite-sdk>=7.75.2",
"prometheus-client>=0.7.0,<=1.0.0",
"arrow>=1.0.0",
"pyyaml>=5.3.0,<7",
Expand Down
8 changes: 6 additions & 2 deletions tests/test_unstable/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from cognite.extractorutils.unstable.configuration.models import (
ConnectionConfig,
ExtractorConfig,
IntegrationConfig,
Scopes,
_ClientCredentialsConfig,
)
from cognite.extractorutils.unstable.core.base import Extractor
Expand Down Expand Up @@ -75,12 +77,14 @@ def connection_config(extraction_pipeline: str) -> ConnectionConfig:
return ConnectionConfig(
project=os.environ["COGNITE_DEV_PROJECT"],
base_url=os.environ["COGNITE_DEV_BASE_URL"],
integration=extraction_pipeline,
integration=IntegrationConfig(external_id=extraction_pipeline),
authentication=_ClientCredentialsConfig(
type="client-credentials",
client_id=os.environ.get("COGNITE_DEV_CLIENT_ID", os.environ["COGNITE_CLIENT_ID"]),
client_secret=os.environ.get("COGNITE_DEV_CLIENT_SECRET", os.environ["COGNITE_CLIENT_SECRET"]),
scopes=os.environ["COGNITE_DEV_TOKEN_SCOPES"].split(","),
scopes=Scopes(
os.environ["COGNITE_DEV_TOKEN_SCOPES"],
),
token_url=os.environ.get("COGNITE_DEV_TOKEN_URL", os.environ["COGNITE_TOKEN_URL"]),
),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unstable/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_simple_task_report(

# Test that the task run is entered into the history for that task
res = extractor.cognite_client.get(
f"/api/v1/projects/{extractor.cognite_client.config.project}/integrations/history?integration={connection_config.integration}&taskName=TestTask",
f"/api/v1/projects/{extractor.cognite_client.config.project}/integrations/history?integration={connection_config.integration.external_id}&taskName=TestTask",
headers={"cdf-version": "alpha"},
).json()

Expand Down
Loading
Loading