Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
387 changes: 269 additions & 118 deletions lib/charms/data_platform_libs/v0/data_interfaces.py

Large diffs are not rendered by default.

103 changes: 74 additions & 29 deletions lib/charms/data_platform_libs/v0/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""A library for communicating with the S3 credentials providers and consumers.
r"""A library for communicating with the S3 credentials providers and consumers.

This library provides the relevant interface code implementing the communication
specification for fetching, retrieving, triggering, and responding to events related to
Expand Down Expand Up @@ -113,23 +113,21 @@ def _on_credential_gone(self, event: CredentialsGoneEvent):
import json
import logging
from collections import namedtuple
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

import ops.charm
import ops.framework
import ops.model
from ops.charm import (
CharmBase,
CharmEvents,
EventSource,
Object,
ObjectEvents,
RelationBrokenEvent,
RelationChangedEvent,
RelationEvent,
RelationJoinedEvent,
)
from ops.model import Relation
from ops.framework import EventSource, Object, ObjectEvents
from ops.model import Application, Relation, RelationDataContent, Unit

# The unique Charmhub library identifier, never change it
LIBID = "fca396f6254246c9bfa565b1f85ab528"
Expand All @@ -139,7 +137,7 @@ def _on_credential_gone(self, event: CredentialsGoneEvent):

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version
LIBPATCH = 2
LIBPATCH = 4

logger = logging.getLogger(__name__)

Expand All @@ -152,7 +150,7 @@ def _on_credential_gone(self, event: CredentialsGoneEvent):
deleted - key that were deleted"""


def diff(event: RelationChangedEvent, bucket: str) -> Diff:
def diff(event: RelationChangedEvent, bucket: Union[Unit, Application]) -> Diff:
"""Retrieves the diff of the data in the relation changed databag.

Args:
Expand All @@ -166,9 +164,11 @@ def diff(event: RelationChangedEvent, bucket: str) -> Diff:
# Retrieve the old data from the data key in the application relation databag.
old_data = json.loads(event.relation.data[bucket].get("data", "{}"))
# Retrieve the new data from the event relation databag.
new_data = {
key: value for key, value in event.relation.data[event.app].items() if key != "data"
}
new_data = (
{key: value for key, value in event.relation.data[event.app].items() if key != "data"}
if event.app
else {}
)

# These are the keys that were added to the databag and triggered this event.
added = new_data.keys() - old_data.keys()
Expand All @@ -193,7 +193,10 @@ class BucketEvent(RelationEvent):
@property
def bucket(self) -> Optional[str]:
"""Returns the bucket was requested."""
return self.relation.data[self.relation.app].get("bucket")
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("bucket", "")


class CredentialRequestedEvent(BucketEvent):
Expand All @@ -209,7 +212,7 @@ class S3CredentialEvents(CharmEvents):
class S3Provider(Object):
"""A provider handler for communicating S3 credentials to consumers."""

on = S3CredentialEvents()
on = S3CredentialEvents() # pyright: ignore [reportGeneralTypeIssues]

def __init__(
self,
Expand All @@ -232,7 +235,9 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
diff = self._diff(event)
# emit on credential requested if bucket is provided by the requirer application
if "bucket" in diff.added:
self.on.credentials_requested.emit(event.relation, app=event.app, unit=event.unit)
getattr(self.on, "credentials_requested").emit(
event.relation, app=event.app, unit=event.unit
)

def _load_relation_data(self, raw_relation_data: dict) -> dict:
"""Loads relation data from the relation data bag.
Expand All @@ -242,7 +247,7 @@ def _load_relation_data(self, raw_relation_data: dict) -> dict:
Returns:
dict: Relation data in dict format.
"""
connection_data = dict()
connection_data = {}
for key in raw_relation_data:
try:
connection_data[key] = json.loads(raw_relation_data[key])
Expand Down Expand Up @@ -309,9 +314,11 @@ def fetch_relation_data(self) -> dict:
"""
data = {}
for relation in self.relations:
data[relation.id] = {
key: value for key, value in relation.data[relation.app].items() if key != "data"
}
data[relation.id] = (
{key: value for key, value in relation.data[relation.app].items() if key != "data"}
if relation.app
else {}
)
return data

def update_connection_info(self, relation_id: int, connection_data: dict) -> None:
Expand Down Expand Up @@ -493,46 +500,73 @@ class S3Event(RelationEvent):
@property
def bucket(self) -> Optional[str]:
"""Returns the bucket name."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("bucket")

@property
def access_key(self) -> Optional[str]:
"""Returns the access key."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("access-key")

@property
def secret_key(self) -> Optional[str]:
"""Returns the secret key."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("secret-key")

@property
def path(self) -> Optional[str]:
"""Returns the path where data can be stored."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("path")

@property
def endpoint(self) -> Optional[str]:
"""Returns the endpoint address."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("endpoint")

@property
def region(self) -> Optional[str]:
"""Returns the region."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("region")

@property
def s3_uri_style(self) -> Optional[str]:
"""Returns the s3 uri style."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("s3-uri-style")

@property
def storage_class(self) -> Optional[str]:
"""Returns the storage class name."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("storage-class")

@property
def tls_ca_chain(self) -> Optional[List[str]]:
"""Returns the TLS CA chain."""
if not self.relation.app:
return None

tls_ca_chain = self.relation.data[self.relation.app].get("tls-ca-chain")
if tls_ca_chain is not None:
return json.loads(tls_ca_chain)
Expand All @@ -541,11 +575,17 @@ def tls_ca_chain(self) -> Optional[List[str]]:
@property
def s3_api_version(self) -> Optional[str]:
"""Returns the S3 API version."""
if not self.relation.app:
return None

return self.relation.data[self.relation.app].get("s3-api-version")

@property
def attributes(self) -> Optional[List[str]]:
"""Returns the attributes."""
if not self.relation.app:
return None

attributes = self.relation.data[self.relation.app].get("attributes")
if attributes is not None:
return json.loads(attributes)
Expand Down Expand Up @@ -573,9 +613,11 @@ class S3CredentialRequiresEvents(ObjectEvents):
class S3Requirer(Object):
"""Requires-side of the s3 relation."""

on = S3CredentialRequiresEvents()
on = S3CredentialRequiresEvents() # pyright: ignore[reportGeneralTypeIssues]

def __init__(self, charm: ops.charm.CharmBase, relation_name: str, bucket_name: str = None):
def __init__(
self, charm: ops.charm.CharmBase, relation_name: str, bucket_name: Optional[str] = None
):
"""Manager of the s3 client relations."""
super().__init__(charm, relation_name)

Expand Down Expand Up @@ -658,15 +700,15 @@ def update_connection_info(self, relation_id: int, connection_data: dict) -> Non
relation.data[self.local_app].update(updated_connection_data)
logger.debug(f"Updated S3 credentials: {updated_connection_data}")

def _load_relation_data(self, raw_relation_data: dict) -> dict:
def _load_relation_data(self, raw_relation_data: RelationDataContent) -> Dict[str, str]:
"""Loads relation data from the relation data bag.

Args:
raw_relation_data: Relation data from the databag
Returns:
dict: Relation data in dict format.
"""
connection_data = dict()
connection_data = {}
for key in raw_relation_data:
try:
connection_data[key] = json.loads(raw_relation_data[key])
Expand Down Expand Up @@ -700,22 +742,25 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
missing_options.append(configuration_option)
# emit credential change event only if all mandatory fields are present
if contains_required_options:
self.on.credentials_changed.emit(event.relation, app=event.app, unit=event.unit)
getattr(self.on, "credentials_changed").emit(
event.relation, app=event.app, unit=event.unit
)
else:
logger.warning(
f"Some mandatory fields: {missing_options} are not present, do not emit credential change event!"
)

def get_s3_connection_info(self) -> Dict:
def get_s3_connection_info(self) -> Dict[str, str]:
"""Return the s3 credentials as a dictionary."""
relation = self.charm.model.get_relation(self.relation_name)
if not relation:
return {}
return self._load_relation_data(relation.data[relation.app])
for relation in self.relations:
if relation and relation.app:
return self._load_relation_data(relation.data[relation.app])

return {}

def _on_relation_broken(self, event: RelationBrokenEvent) -> None:
"""Notify the charm about a broken S3 credential store relation."""
self.on.credentials_gone.emit(event.relation, app=event.app, unit=event.unit)
getattr(self.on, "credentials_gone").emit(event.relation, app=event.app, unit=event.unit)

@property
def relations(self) -> List[Relation]:
Expand Down
Loading