Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0aaddca
WIP: implement GCSObjectMetadataClient class to attach custom-metadat…
Feb 27, 2025
dce1a73
WIP: change dump method interface.
Feb 27, 2025
6b0dc4c
WIP: add user_provided_gcs_labels parameter to TaskOnKart.
Feb 27, 2025
ed86bcb
Test: add test of GCSObjectMetadataClient.
Feb 27, 2025
38fa924
feat: dealt with nit PR comments.
Feb 27, 2025
fcfdd56
feat: Remove user_provided_labels feature. This feature will be suppo…
Feb 27, 2025
de1fb45
for-PR: apply almost all comments.
Feb 28, 2025
c11abea
fix: change Dict to dict.
Feb 28, 2025
e471163
feat: add gokart specific parameter serialize test.
Feb 28, 2025
47195b9
fix: fix testcases with literals and more meaningful assertion.
Feb 28, 2025
c8931f2
feat: add mock testcase.
Feb 28, 2025
cc320b5
fix: fix CI errors.
Feb 28, 2025
d8c6ec5
feat: deal with pr comments, and modify testcases.
Mar 2, 2025
ca6d69e
feat: deal with kitagry comments.
Mar 3, 2025
08549a7
feat: Support functionalities to add user specific original labels and…
Mar 3, 2025
a73801f
feat: remove namespace feature.
Mar 3, 2025
aa33064
feat: resolve conflicts.
Mar 3, 2025
a047f6a
feat: Deal with kitagry PR comments. Added type annotations.
Mar 3, 2025
2d4322c
CI: apply ruff
Mar 3, 2025
6e80c2f
feat: rename user_provided labels to custom_labels.
Mar 3, 2025
f62742e
CI: apply ruff
Mar 3, 2025
91b8282
feat: Deal with PR comments. Change _normalize_labels signature and a…
Mar 4, 2025
6364394
feat: deal with PR comments.
Mar 4, 2025
d7cfcdc
feat: deal with PR comments.
Mar 4, 2025
13dc761
CI: fix tests
Mar 4, 2025
a4d6a0b
feat: deal with PR comments.
Mar 4, 2025
3315be5
CI: apply ruff
Mar 5, 2025
fc4b1fb
feat: deal with kitagry PR comments, responsibility separation.
Mar 5, 2025
337fab8
feat: deal with yokomotod PR comments, use hoge | None expression ins…
Mar 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import copy
import re
from logging import getLogger
from typing import Any, Optional, Union
from typing import Any, Union
from urllib.parse import urlsplit

from googleapiclient.model import makepatch
Expand All @@ -18,23 +19,25 @@ class GCSObjectMetadataClient:
This class is used for adding metadata as labels.
"""

@staticmethod
def _is_log_related_path(path: str) -> bool:
    """Return True when ``path`` points into one of the ``log/`` sub-directories.

    Recognised sub-directories are ``processing_time``, ``task_info``, ``task_log``,
    ``module_versions``, ``random_seed`` and ``task_params``; at least one character must
    follow the sub-directory name.

    Both a bare object key (``log/task_info/...``) and a full GCS URL
    (``gs://<bucket>/.../log/task_info/...``) are accepted. ``add_task_state_labels`` hands
    over the full ``gs://`` URL, so a pattern anchored directly at ``log/`` would never match
    there and the log files would not be skipped.
    """
    # The optional, non-greedy ``gs://.+?/`` prefix swallows the bucket and any workspace
    # directories in front of ``log/``; the remainder is the original pattern.
    pattern = r'^(?:gs://.+?/)?log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+'
    return re.match(pattern, path) is not None

# This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
@staticmethod
def path_to_bucket_and_key(path: str) -> tuple[str, str]:
def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
(scheme, netloc, path, _, _) = urlsplit(path)
assert scheme == 'gs'
path_without_initial_slash = path[1:]
return netloc, path_without_initial_slash

@staticmethod
def add_task_state_labels(
path: str,
task_params: Optional[dict[Any, str]] = None,
) -> None:
def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
if GCSObjectMetadataClient._is_log_related_path(path):
return
# In gokart/object_storage.get_time_stamp, could find same call.
# _path_to_bucket_and_key is a private method, so, this might not be acceptable.
bucket, obj = GCSObjectMetadataClient.path_to_bucket_and_key(path)

bucket, obj = GCSObjectMetadataClient._path_to_bucket_and_key(path)
_response = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute()
if _response is None:
logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
Expand All @@ -50,6 +53,7 @@ def add_task_state_labels(
patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
copy.deepcopy(original_metadata),
task_params,
custom_labels,
)

if original_metadata != patched_metadata:
Expand All @@ -71,30 +75,71 @@ def add_task_state_labels(
if update_response is None:
logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')

@staticmethod
def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
    """Stringify every key and value of ``labels``.

    ``None`` and an empty mapping both produce an empty dict.
    """
    if not labels:
        return {}

    normalized = {}
    for name, content in labels.items():
        normalized[str(name)] = str(content)
    return normalized

@staticmethod
def _get_patched_obj_metadata(
metadata: Any,
task_params: Optional[dict[Any, str]] = None,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
) -> Union[dict, Any]:
# If the metadata in the response to the bucket/object lookup is not a dictionary,
# something probably went wrong, so return the original metadata unpatched.
if not isinstance(metadata, dict):
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
return metadata

if not task_params:
if not task_params and not custom_labels:
return metadata
# Maximum size of metadata for each object is 8 KiB.
# [Link]: https://cloud.google.com/storage/quotas#objects
max_gcs_metadata_size, total_metadata_size, labels = 8 * 1024, 0, []
for label_name, label_value in task_params.items():
normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params)
normalized_custom_labels = GCSObjectMetadataClient._normalize_labels(custom_labels)
# There is a possibility that the keys of user-provided labels(custom_labels) may conflict with those generated from task parameters (task_params_labels).
# However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
# Instead, users are expected to search using the labels they provided.
# Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
return dict(metadata) | dict(GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(_merged_labels))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(_merged_labels) returns 7.9 KiB of labels and metadata already holds 0.2 KiB, the combined size will exceed 8 KiB. Is that OK?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh... sorry, I will try to fix this and resubmit.


@staticmethod
def _merge_custom_labels_and_task_params_labels(
    normalized_task_params: dict[str, str],
    normalized_custom_labels: dict[str, Any],
) -> dict[str, str]:
    """Merge task-parameter labels into the user-provided custom labels.

    Custom labels come first and win on key conflicts: a task parameter whose name is already
    present is skipped with a warning. Task parameters with an empty value are skipped with a
    warning as well. No size limit is applied here.

    Args:
        normalized_task_params: labels generated from task parameters, already stringified.
        normalized_custom_labels: user-provided labels, already stringified.

    Returns:
        A new dict; neither argument is modified.
    """
    # A shallow copy is enough: entries are only added, existing values are never mutated.
    merged_labels: dict[str, str] = dict(normalized_custom_labels)
    for label_name, label_value in normalized_task_params.items():
        if not label_value:
            logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
            continue
        if label_name in merged_labels:
            logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
            continue
        merged_labels[label_name] = label_value
    return merged_labels

# Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
# So, we need to adjust the size of metadata.
@staticmethod
def _adjust_gcs_metadata_limit_size(_labels: dict[str, str]) -> dict[str, str]:
    """Return a copy of ``_labels`` trimmed so that its total size fits into 8 KiB.

    The size of a label is the UTF-8 byte length of its name plus that of its value.
    While the total exceeds the limit, the most recently inserted label is dropped.
    The input dict is never modified.
    """
    max_gcs_metadata_size = 8 * 1024

    def _get_label_size(label_name: str, label_value: str) -> int:
        return len(label_name.encode('utf-8')) + len(label_value.encode('utf-8'))

    labels = dict(_labels)
    current_total_metadata_size = sum(_get_label_size(name, value) for name, value in labels.items())

    # popitem() removes entries in LIFO order, i.e. the same order as walking the dict in
    # reverse, but without deleting from a dict while iterating over it, which raises
    # RuntimeError as soon as a second label has to be removed.
    while labels and current_total_metadata_size > max_gcs_metadata_size:
        label_name, label_value = labels.popitem()
        current_total_metadata_size -= _get_label_size(label_name, label_value)
    return labels
6 changes: 4 additions & 2 deletions gokart/in_memory/target.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Optional
from typing import Any

from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.target import TargetOnKart, TaskLockParams
Expand All @@ -24,7 +26,7 @@ def _get_task_lock_params(self) -> TaskLockParams:
def _load(self) -> Any:
return _repository.get_value(self._data_key)

def _dump(self, obj: Any, task_params: Optional[dict[str, str]] = None) -> None:
def _dump(self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
return _repository.set_value(self._data_key, obj)

def _remove(self) -> None:
Expand Down
18 changes: 11 additions & 7 deletions gokart/target.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import hashlib
import os
import shutil
Expand Down Expand Up @@ -28,11 +30,13 @@ def exists(self) -> bool:
def load(self) -> Any:
return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()

def dump(self, obj, lock_at_dump: bool = True, task_params: Optional[dict[str, str]] = None) -> None:
def dump(self, obj, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
if lock_at_dump:
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(obj=obj, task_params=task_params)
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
obj=obj, task_params=task_params, custom_labels=custom_labels
)
else:
self._dump(obj=obj, task_params=task_params)
self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels)

def remove(self) -> None:
if self.exists():
Expand All @@ -57,7 +61,7 @@ def _load(self) -> Any:
pass

@abstractmethod
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
pass

@abstractmethod
Expand Down Expand Up @@ -94,11 +98,11 @@ def _load(self) -> Any:
with self._target.open('r') as f:
return self._processor.load(f)

def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
with self._target.open('w') as f:
self._processor.dump(obj, f)
if self.path().startswith('gs://'):
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params)
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels)

def _remove(self) -> None:
self._target.remove()
Expand Down Expand Up @@ -138,7 +142,7 @@ def _load(self) -> Any:
self._remove_temporary_directory()
return model

def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
self._make_temporary_directory()
self._save_function(obj, self._model_path())
make_target(self._load_function_path()).dump(self._load_function, task_params=task_params)
Expand Down
15 changes: 11 additions & 4 deletions gokart/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import functools
import hashlib
import inspect
Expand Down Expand Up @@ -350,17 +352,22 @@ def _flatten_recursively(dfs):
return data

@overload
def dump(self, obj: T, target: None = None) -> None: ...
def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ...

@overload
def dump(self, obj: Any, target: Union[str, TargetOnKart]) -> None: ...
def dump(self, obj: Any, target: Union[str, TargetOnKart], custom_labels: dict[Any, Any] | None = None) -> None: ...

def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None) -> None:
def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_labels: dict[str, Any] | None = None) -> None:
PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
assert not obj.empty

self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True))
self._get_output_target(target).dump(
obj,
lock_at_dump=self._lock_at_dump,
task_params=super().to_str_params(only_significant=True, only_public=True),
custom_labels=custom_labels,
)

@staticmethod
def get_code(target_class) -> Set[str]:
Expand Down
79 changes: 72 additions & 7 deletions test/test_gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,47 @@ def run(self):


class TestGCSObjectMetadataClient(unittest.TestCase):
def test_get_patched_obj_metadata(self):
task_params: dict[Any, str] = {
def setUp(self):
self.task_params: dict[str, str] = {
'param1': 'a' * 1000,
'param2': str(1000),
'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}),
'param4': str([1, 2, 3, 4, 5]),
'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)),
'param6': '',
}
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=task_params)
self.custom_labels: dict[str, Any] = {
'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5),
'created_by': 'hoge fuga',
'empty': True,
'try_num': 3,
}

self.task_params_with_conflicts = {
'empty': 'False',
'created_by': 'fuga hoge',
'param1': 'a' * 10,
}

def test_normalize_labels_not_empty(self):
got = GCSObjectMetadataClient._normalize_labels(None)
self.assertEqual(got, {})

def test_normalize_labels_has_value(self):
    """Labels whose keys and values are already strings must come back unchanged."""
    got = GCSObjectMetadataClient._normalize_labels(self.task_params)

    self.assertIsInstance(got, dict)
    # Every fixture entry is str -> str, so normalization must be the identity here.
    # Comparing the whole dict checks the values as well as the presence of each key,
    # including 'param6' with its empty value.
    self.assertEqual(got, self.task_params)
    for value in got.values():
        self.assertIsInstance(value, str)

def test_get_patched_obj_metadata_only_task_params(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None)

self.assertIsInstance(got, dict)
self.assertIn('param1', got)
self.assertIn('param2', got)
Expand All @@ -34,17 +65,52 @@ def test_get_patched_obj_metadata(self):
self.assertIn('param5', got)
self.assertNotIn('param6', got)

def test_get_patched_obj_metadata_only_custom_labels(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels)

self.assertIsInstance(got, dict)
self.assertIn('created_at', got)
self.assertIn('created_by', got)
self.assertIn('empty', got)
self.assertIn('try_num', got)

def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels)

self.assertIsInstance(got, dict)
self.assertIn('param1', got)
self.assertIn('param2', got)
self.assertIn('param3', got)
self.assertIn('param4', got)
self.assertIn('param5', got)
self.assertNotIn('param6', got)
self.assertIn('created_at', got)
self.assertIn('created_by', got)
self.assertIn('empty', got)
self.assertIn('try_num', got)

def test_get_patched_obj_metadata_with_exceeded_size_metadata(self):
task_params = {
size_exceeded_task_params = {
'param1': 'a' * 5000,
'param2': 'b' * 5000,
}
want = {
'param1': 'a' * 5000,
}
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=task_params)
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params)
self.assertEqual(got, want)

def test_get_patched_obj_metadata_with_conflicts(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels)
self.assertIsInstance(got, dict)
self.assertIn('created_at', got)
self.assertIn('created_by', got)
self.assertIn('empty', got)
self.assertIn('try_num', got)
self.assertEqual(got['empty'], 'True')
self.assertEqual(got['created_by'], 'hoge fuga')
self.assertEqual(got['param1'], 'a' * 10)


class TestGokartTask(unittest.TestCase):
@patch.object(_DummyTaskOnKart, '_get_output_target')
Expand All @@ -54,8 +120,7 @@ def test_mock_target_on_kart(self, mock_get_output_target):

task = _DummyTaskOnKart()
task.dump({'key': 'value'}, mock_target)

mock_target.dump.assert_called_once_with({'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={})
mock_target.dump.assert_called_once_with({'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def test_should_fail_lock_run_when_port_unset(self):


class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
def dump(self, _obj: Any, _target: Any = None):
def dump(self, _obj: Any, _target: Any = None, _custom_labels: Any = None):
# override dump() to do nothing.
pass

Expand All @@ -625,7 +625,7 @@ def complete(self):


class _DummyTaskWithCompleted(gokart.TaskOnKart):
def dump(self, obj: Any, _target: Any = None):
def dump(self, obj: Any, _target: Any = None, custom_labels: Any = None):
# override dump() to do nothing.
pass

Expand Down