Skip to content

Commit 78e6bb0

Browse files
tyzerrr and 荒木 太一 authored
Support functionalities to provide metadata as labels to object stored in GCS as cache. (#445)
* WIP: implement GCSObjectMetadataClient class to attach custom-metadata to gcs-object.
* WIP: change dump method interface.
* WIP: add user_provided_gcs_labels parameter to TaskOnKart.
* Test: add test of GCSObjectMetadataClient.
* feat: dealed with nits PR comments.
* feat: Remove user_provided_labels feature. This feature will be supported another PR.
* for-PR: apply almost all comments.
* fix: change Dict to dict.
* feat: add gokart specific parameter serialize test.
* fix: fix testcases with literals and more meaningful assertion.
* feat: add mock testcase.
* fix: fix CI errors.
* feat: deal with pr comments, and modify testcases.
* feat: deal with kitagry comments.
* feat: deal with PR comments, typle annotation, sealed private parameters.

---------

Co-authored-by: 荒木 太一 <taichi.araki@m3-2024mac75.local>
1 parent 5d9d7fa commit 78e6bb0

5 files changed

Lines changed: 176 additions & 10 deletions

File tree

gokart/gcs_obj_metadata_client.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
from logging import getLogger
5+
from typing import Any, Optional, Union
6+
from urllib.parse import urlsplit
7+
8+
from googleapiclient.model import makepatch
9+
10+
from gokart.gcs_config import GCSConfig
11+
12+
logger = getLogger(__name__)


class GCSObjectMetadataClient:
    """Attach task parameters to GCS objects as custom metadata ("labels").

    This is a utility class made only of static methods; it is not meant to
    be instantiated.
    """

    # Maximum combined size of custom metadata keys and values per object is 8 KiB.
    # [Link]: https://cloud.google.com/storage/quotas#objects
    _MAX_GCS_METADATA_SIZE = 8 * 1024

    # Copy of luigi's private `_path_to_bucket_and_key(path)` helper, kept here
    # so that this module does not depend on a private luigi function.
    @staticmethod
    def path_to_bucket_and_key(path: str) -> tuple[str, str]:
        """Split a ``gs://bucket/key`` URL into ``(bucket, key)``.

        The key is returned without its leading slash. A scheme other than
        ``gs`` trips an assertion.
        """
        (scheme, netloc, path, _, _) = urlsplit(path)
        assert scheme == 'gs'
        path_without_initial_slash = path[1:]
        return netloc, path_without_initial_slash

    @staticmethod
    def add_task_state_labels(
        path: str,
        task_params: Optional[dict[Any, str]] = None,
    ) -> None:
        """Merge ``task_params`` into the custom metadata of the object at ``path``.

        The current object resource is fetched first, the patched metadata is
        computed with ``_get_patched_obj_metadata`` and, only if it differs
        from the current metadata, sent back with the ``patch`` API.

        Args:
            path: ``gs://`` URL of the object to label.
            task_params: label name -> label value. ``None`` or an empty
                mapping leaves the object untouched.
        """
        bucket, obj = GCSObjectMetadataClient.path_to_bucket_and_key(path)

        # Build the client once and reuse it for both the lookup and the patch.
        client = GCSConfig().get_gcs_client().client

        fetched = client.objects().get(bucket=bucket, object=obj).execute()
        if fetched is None:
            logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
            return

        # The 'metadata' key may be missing or null in the response; treat both as empty.
        original_metadata: dict[Any, Any] = dict(dict(fetched).get('metadata') or {})

        # A deep copy is handed over so the helper can never alias `original_metadata`,
        # which is still needed below to compute the patch body.
        patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
            copy.deepcopy(original_metadata),
            task_params,
        )
        if original_metadata == patched_metadata:
            return

        # If we use update api, existing object metadata are removed, so should use patch api.
        # See the official document descriptions.
        # [Link] https://cloud.google.com/storage/docs/viewing-editing-metadata?hl=ja#rest-set-object-metadata
        update_response = (
            client.objects()
            .patch(
                bucket=bucket,
                object=obj,
                body=makepatch({'metadata': original_metadata}, {'metadata': patched_metadata}),
            )
            .execute()
        )

        if update_response is None:
            logger.error(f'failed to patch object {obj} in bucket {bucket}.')

    @staticmethod
    def _metadata_entry_size(name: Any, value: Any) -> int:
        """Return the UTF-8 byte size of one metadata entry (key plus value)."""
        return len(str(name).encode('utf-8')) + len(str(value).encode('utf-8'))

    @staticmethod
    def _get_patched_obj_metadata(
        metadata: Any,
        task_params: Optional[dict[Any, str]] = None,
    ) -> Union[dict, Any]:
        """Return ``metadata`` with the entries of ``task_params`` merged in.

        Labels are added in iteration order of ``task_params``. Labels with an
        empty value are skipped. As soon as adding a label would push the
        combined size of all metadata (entries already present plus labels
        added so far) above 8 KiB, that label and all remaining ones are
        dropped.

        Args:
            metadata: current custom metadata of the object. If it is not a
                dict it is returned unchanged.
            task_params: label name -> label value.

        Returns:
            ``metadata`` itself when there is nothing to add, otherwise a new
            dict containing the merged metadata.
        """
        # If metadata from response when getting bucket and object information is not dictionary,
        # something wrong might be happened, so return original metadata, no patched.
        if not isinstance(metadata, dict):
            logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
            return metadata

        if not task_params:
            return metadata

        entry_size = GCSObjectMetadataClient._metadata_entry_size
        patched = dict(metadata)
        # Entries already stored on the object count against the same quota.
        total_metadata_size = sum(entry_size(name, value) for name, value in patched.items())

        for label_name, label_value in task_params.items():
            if not label_value:
                logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
                continue
            # Overwriting an existing key frees the bytes of the entry it replaces.
            replaced_size = entry_size(label_name, patched[label_name]) if label_name in patched else 0
            new_total_size = total_metadata_size - replaced_size + entry_size(label_name, label_value)
            if new_total_size > GCSObjectMetadataClient._MAX_GCS_METADATA_SIZE:
                logger.warning(f'current metadata total size is {total_metadata_size} byte, and no more labels would be added.')
                break
            total_metadata_size = new_total_size
            patched[label_name] = label_value
        return patched

gokart/in_memory/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Any
2+
from typing import Any, Optional
33

44
from gokart.in_memory.repository import InMemoryCacheRepository
55
from gokart.target import TargetOnKart, TaskLockParams
@@ -24,7 +24,7 @@ def _get_task_lock_params(self) -> TaskLockParams:
2424
def _load(self) -> Any:
2525
return _repository.get_value(self._data_key)
2626

27-
def _dump(self, obj: Any) -> None:
27+
def _dump(self, obj: Any, task_params: Optional[dict[str, str]] = None) -> None:
2828
return _repository.set_value(self._data_key, obj)
2929

3030
def _remove(self) -> None:

gokart/target.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params
1515
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
1616
from gokart.file_processor import FileProcessor, make_file_processor
17+
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
1718
from gokart.object_storage import ObjectStorage
1819
from gokart.zip_client_util import make_zip_client
1920

@@ -27,11 +28,11 @@ def exists(self) -> bool:
2728
def load(self) -> Any:
2829
return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()
2930

30-
def dump(self, obj, lock_at_dump: bool = True) -> None:
31+
def dump(self, obj, lock_at_dump: bool = True, task_params: Optional[dict[str, str]] = None) -> None:
3132
if lock_at_dump:
32-
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(obj)
33+
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(obj=obj, task_params=task_params)
3334
else:
34-
self._dump(obj)
35+
self._dump(obj=obj, task_params=task_params)
3536

3637
def remove(self) -> None:
3738
if self.exists():
@@ -56,7 +57,7 @@ def _load(self) -> Any:
5657
pass
5758

5859
@abstractmethod
59-
def _dump(self, obj) -> None:
60+
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
6061
pass
6162

6263
@abstractmethod
@@ -93,9 +94,11 @@ def _load(self) -> Any:
9394
with self._target.open('r') as f:
9495
return self._processor.load(f)
9596

96-
def _dump(self, obj) -> None:
97+
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
9798
with self._target.open('w') as f:
9899
self._processor.dump(obj, f)
100+
if self.path().startswith('gs://'):
101+
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params)
99102

100103
def _remove(self) -> None:
101104
self._target.remove()
@@ -135,10 +138,10 @@ def _load(self) -> Any:
135138
self._remove_temporary_directory()
136139
return model
137140

138-
def _dump(self, obj) -> None:
141+
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
139142
self._make_temporary_directory()
140143
self._save_function(obj, self._model_path())
141-
make_target(self._load_function_path()).dump(self._load_function)
144+
make_target(self._load_function_path()).dump(self._load_function, task_params=task_params)
142145
self._zip_client.make_archive()
143146
self._remove_temporary_directory()
144147

gokart/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,8 @@ def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None) -> None:
359359
PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
360360
if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
361361
assert not obj.empty
362-
self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump)
362+
363+
self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True))
363364

364365
@staticmethod
365366
def get_code(target_class) -> Set[str]:
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import datetime
2+
import unittest
3+
from typing import Any
4+
from unittest.mock import MagicMock, patch
5+
6+
import gokart
7+
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
8+
from gokart.target import TargetOnKart
9+
10+
11+
class _DummyTaskOnKart(gokart.TaskOnKart):
    """Minimal task fixture: running it dumps a constant string."""

    task_namespace = __name__

    def run(self):
        self.dump('Dummy TaskOnKart')
16+
17+
18+
class TestGCSObjectMetadataClient(unittest.TestCase):
    """Tests for GCSObjectMetadataClient._get_patched_obj_metadata."""

    def test_get_patched_obj_metadata(self):
        """Every non-empty parameter becomes a label; the empty one is dropped."""
        timestamp = datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)
        task_params: dict[Any, str] = {
            'param1': 'a' * 1000,
            'param2': str(1000),
            'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}),
            'param4': str([1, 2, 3, 4, 5]),
            'param5': str(timestamp),
            'param6': '',
        }

        patched = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=task_params)

        self.assertIsInstance(patched, dict)
        for kept_name in ('param1', 'param2', 'param3', 'param4', 'param5'):
            self.assertIn(kept_name, patched)
        self.assertNotIn('param6', patched)

    def test_get_patched_obj_metadata_with_exceeded_size_metadata(self):
        """Labels that would push the total past the size limit are not added."""
        first_value = 'a' * 5000
        second_value = 'b' * 5000

        patched = GCSObjectMetadataClient._get_patched_obj_metadata(
            {},
            task_params={'param1': first_value, 'param2': second_value},
        )

        self.assertEqual(patched, {'param1': first_value})
47+
48+
49+
class TestGokartTask(unittest.TestCase):
    """Checks how TaskOnKart.dump forwards its arguments to the output target."""

    @patch.object(_DummyTaskOnKart, '_get_output_target')
    def test_mock_target_on_kart(self, mock_get_output_target):
        """dump() calls the target once with the object, lock flag and task_params."""
        fake_target = MagicMock(spec=TargetOnKart)
        mock_get_output_target.return_value = fake_target
        payload = {'key': 'value'}

        dummy_task = _DummyTaskOnKart()
        dummy_task.dump(payload, fake_target)

        fake_target.dump.assert_called_once_with(
            {'key': 'value'},
            lock_at_dump=dummy_task._lock_at_dump,
            task_params={},
        )
59+
60+
61+
# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)