Skip to content

Commit 77d6a4a

Browse files
tyzerrr and 荒木 太一 authored
Support functionalities to add user-provided, original labels. (#446)
* WIP: implement GCSObjectMetadataClient class to attach custom-metadata to gcs-object. * WIP: change dump method interface. * WIP: add user_provided_gcs_labels parameter to TaskOnKart. * Test: add test of GCSObjectMetadataClient. * feat: dealt with nits PR comments. * feat: Remove user_provided_labels feature. This feature will be supported in another PR. * for-PR: apply almost all comments. * fix: change Dict to dict. * feat: add gokart specific parameter serialize test. * fix: fix testcases with literals and more meaningful assertion. * feat: add mock testcase. * fix: fix CI errors. * feat: deal with pr comments, and modify testcases. * feat: deal with kitagry comments. * feat: Support functionalities to add user-specific original labels and namespace support feature and test it. * feat: remove namespace feature. * feat: Deal with kitagry PR comments. Added type annotations. * CI: apply ruff * feat: rename user_provided labels to custom_labels. * CI: apply ruff * feat: Deal with PR comments. Change _normalize_labels signature and apply changes to test. * feat: deal with PR comments. * feat: deal with PR comments. * CI: fix tests * feat: deal with PR comments. * CI: apply ruff * feat: deal with kitagry PR comments, responsibility separation. * feat: deal with yokomotod PR comments, use hoge | None expression instead of Optional. --------- Co-authored-by: 荒木 太一 <taichi.araki@m3-2024mac75.local>
1 parent 78e6bb0 commit 77d6a4a

6 files changed

Lines changed: 163 additions & 40 deletions

File tree

gokart/gcs_obj_metadata_client.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
import copy
4+
import re
45
from logging import getLogger
5-
from typing import Any, Optional, Union
6+
from typing import Any, Union
67
from urllib.parse import urlsplit
78

89
from googleapiclient.model import makepatch
@@ -18,23 +19,25 @@ class GCSObjectMetadataClient:
1819
This class used for adding metadata as labels.
1920
"""
2021

22+
@staticmethod
23+
def _is_log_related_path(path: str) -> bool:
    """Return True when ``path`` points inside one of the ``log/`` sub-directories.

    The recognised sub-directories are ``processing_time``, ``task_info``,
    ``task_log``, ``module_versions``, ``random_seed`` and ``task_params``,
    and at least one character must follow the sub-directory name.

    ``add_task_state_labels`` passes full ``gs://bucket/key`` URLs, so a
    pattern anchored at the very start of the string (``^log/``) can never
    match them.  The ``scheme://bucket/`` prefix is therefore removed first
    and the ``log/<kind>/`` segment is looked up anywhere in the object key.
    Bare relative paths such as ``log/task_info/x`` are still recognised.

    Args:
        path: A ``gs://`` URL or a plain object key.

    Returns:
        True if the object key contains a ``log/<kind>/<name>`` segment.
    """
    # Strip "scheme://bucket/" so that a bucket that happens to be named
    # "log" is not mistaken for the log directory.
    _, scheme_sep, remainder = path.partition('://')
    object_key = remainder.partition('/')[2] if scheme_sep else path
    # "(?:^|/)" makes sure "log" is a whole path segment (not e.g. "catalog/").
    found = re.search(
        r'(?:^|/)log/(?:processing_time|task_info|task_log|module_versions|random_seed|task_params)/.+',
        object_key,
    )
    return found is not None
25+
2126
# This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
2227
@staticmethod
23-
def path_to_bucket_and_key(path: str) -> tuple[str, str]:
28+
def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
2429
(scheme, netloc, path, _, _) = urlsplit(path)
2530
assert scheme == 'gs'
2631
path_without_initial_slash = path[1:]
2732
return netloc, path_without_initial_slash
2833

2934
@staticmethod
30-
def add_task_state_labels(
31-
path: str,
32-
task_params: Optional[dict[Any, str]] = None,
33-
) -> None:
35+
def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
36+
if GCSObjectMetadataClient._is_log_related_path(path):
37+
return
3438
# In gokart/object_storage.get_time_stamp, could find same call.
3539
# _path_to_bucket_and_key is a private method, so, this might not be acceptable.
36-
bucket, obj = GCSObjectMetadataClient.path_to_bucket_and_key(path)
37-
40+
bucket, obj = GCSObjectMetadataClient._path_to_bucket_and_key(path)
3841
_response = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute()
3942
if _response is None:
4043
logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
@@ -50,6 +53,7 @@ def add_task_state_labels(
5053
patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
5154
copy.deepcopy(original_metadata),
5255
task_params,
56+
custom_labels,
5357
)
5458

5559
if original_metadata != patched_metadata:
@@ -71,30 +75,71 @@ def add_task_state_labels(
7175
if update_response is None:
7276
logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')
7377

78+
@staticmethod
79+
def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
80+
return {str(key): str(value) for key, value in labels.items()} if labels else {}
81+
7482
@staticmethod
7583
def _get_patched_obj_metadata(
7684
metadata: Any,
77-
task_params: Optional[dict[Any, str]] = None,
85+
task_params: dict[str, str] | None = None,
86+
custom_labels: dict[str, Any] | None = None,
7887
) -> Union[dict, Any]:
7988
# If metadata from response when getting bucket and object information is not dictionary,
8089
# something wrong might be happened, so return original metadata, no patched.
8190
if not isinstance(metadata, dict):
8291
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
8392
return metadata
8493

85-
if not task_params:
94+
if not task_params and not custom_labels:
8695
return metadata
8796
# Maximum size of metadata for each object is 8 KiB.
8897
# [Link]: https://cloud.google.com/storage/quotas#objects
89-
max_gcs_metadata_size, total_metadata_size, labels = 8 * 1024, 0, []
90-
for label_name, label_value in task_params.items():
98+
normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params)
99+
normalized_custom_labels = GCSObjectMetadataClient._normalize_labels(custom_labels)
100+
# There is a possibility that the keys of user-provided labels(custom_labels) may conflict with those generated from task parameters (task_params_labels).
101+
# However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
102+
# Instead, users are expected to search using the labels they provided.
103+
# Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
104+
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
105+
return dict(metadata) | dict(GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(_merged_labels))
106+
107+
@staticmethod
108+
def _merge_custom_labels_and_task_params_labels(
109+
normalized_task_params: dict[str, str],
110+
normalized_custom_labels: dict[str, Any],
111+
) -> dict[str, str]:
112+
merged_labels = copy.deepcopy(normalized_custom_labels)
113+
for label_name, label_value in normalized_task_params.items():
91114
if len(label_value) == 0:
92115
logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
93116
continue
94-
size = len(str(label_name).encode('utf-8')) + len(str(label_value).encode('utf-8'))
95-
if total_metadata_size + size > max_gcs_metadata_size:
96-
logger.warning(f'current metadata total size is {total_metadata_size} byte, and no more labels would be added.')
117+
if label_name in merged_labels.keys():
118+
logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
119+
continue
120+
merged_labels[label_name] = label_value
121+
return merged_labels
122+
123+
# Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
124+
# So, we need to adjust the size of metadata.
125+
@staticmethod
126+
def _adjust_gcs_metadata_limit_size(_labels: dict[str, str]) -> dict[str, str]:
    """Return a copy of ``_labels`` trimmed to fit the 8 KiB metadata budget.

    The size of a label is the UTF-8 byte length of its name plus the UTF-8
    byte length of its value.  While the total exceeds 8 KiB, labels are
    discarded starting from the most recently inserted one, so labels that
    were inserted first survive the longest.

    Args:
        _labels: Mapping of label name to label value (both strings).  The
            mapping itself is never modified.

    Returns:
        A new dict holding the leading labels whose combined size is at most
        8 KiB (possibly empty).
    """
    max_gcs_metadata_size = 8 * 1024

    def _get_label_size(label_name: str, label_value: str) -> int:
        return len(label_name.encode('utf-8')) + len(label_value.encode('utf-8'))

    # Keys and values are immutable strings, so a shallow copy is sufficient
    # to keep the caller's dict untouched.
    labels = dict(_labels)
    current_total_metadata_size = sum(_get_label_size(name, value) for name, value in labels.items())

    # dict.popitem() removes entries in LIFO order.  Popping in a while-loop
    # (instead of deleting inside a for-loop over the same dict) avoids
    # "RuntimeError: dictionary changed size during iteration" when more than
    # one label has to be dropped.
    while labels and current_total_metadata_size > max_gcs_metadata_size:
        dropped_name, dropped_value = labels.popitem()
        current_total_metadata_size -= _get_label_size(dropped_name, dropped_value)
    return labels

gokart/in_memory/target.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
from datetime import datetime
2-
from typing import Any, Optional
4+
from typing import Any
35

46
from gokart.in_memory.repository import InMemoryCacheRepository
57
from gokart.target import TargetOnKart, TaskLockParams
@@ -24,7 +26,7 @@ def _get_task_lock_params(self) -> TaskLockParams:
2426
def _load(self) -> Any:
2527
return _repository.get_value(self._data_key)
2628

27-
def _dump(self, obj: Any, task_params: Optional[dict[str, str]] = None) -> None:
29+
def _dump(self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
2830
return _repository.set_value(self._data_key, obj)
2931

3032
def _remove(self) -> None:

gokart/target.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import hashlib
24
import os
35
import shutil
@@ -28,11 +30,13 @@ def exists(self) -> bool:
2830
def load(self) -> Any:
2931
return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()
3032

31-
def dump(self, obj, lock_at_dump: bool = True, task_params: Optional[dict[str, str]] = None) -> None:
33+
def dump(self, obj, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
3234
if lock_at_dump:
33-
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(obj=obj, task_params=task_params)
35+
wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
36+
obj=obj, task_params=task_params, custom_labels=custom_labels
37+
)
3438
else:
35-
self._dump(obj=obj, task_params=task_params)
39+
self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels)
3640

3741
def remove(self) -> None:
3842
if self.exists():
@@ -57,7 +61,7 @@ def _load(self) -> Any:
5761
pass
5862

5963
@abstractmethod
60-
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
64+
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
6165
pass
6266

6367
@abstractmethod
@@ -94,11 +98,11 @@ def _load(self) -> Any:
9498
with self._target.open('r') as f:
9599
return self._processor.load(f)
96100

97-
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
101+
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
98102
with self._target.open('w') as f:
99103
self._processor.dump(obj, f)
100104
if self.path().startswith('gs://'):
101-
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params)
105+
GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels)
102106

103107
def _remove(self) -> None:
104108
self._target.remove()
@@ -138,7 +142,7 @@ def _load(self) -> Any:
138142
self._remove_temporary_directory()
139143
return model
140144

141-
def _dump(self, obj, task_params: Optional[dict[str, str]] = None) -> None:
145+
def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
142146
self._make_temporary_directory()
143147
self._save_function(obj, self._model_path())
144148
make_target(self._load_function_path()).dump(self._load_function, task_params=task_params)

gokart/task.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import functools
24
import hashlib
35
import inspect
@@ -350,17 +352,22 @@ def _flatten_recursively(dfs):
350352
return data
351353

352354
@overload
353-
def dump(self, obj: T, target: None = None) -> None: ...
355+
def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ...
354356

355357
@overload
356-
def dump(self, obj: Any, target: Union[str, TargetOnKart]) -> None: ...
358+
def dump(self, obj: Any, target: Union[str, TargetOnKart], custom_labels: dict[Any, Any] | None = None) -> None: ...
357359

358-
def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None) -> None:
360+
def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_labels: dict[str, Any] | None = None) -> None:
359361
PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
360362
if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
361363
assert not obj.empty
362364

363-
self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True))
365+
self._get_output_target(target).dump(
366+
obj,
367+
lock_at_dump=self._lock_at_dump,
368+
task_params=super().to_str_params(only_significant=True, only_public=True),
369+
custom_labels=custom_labels,
370+
)
364371

365372
@staticmethod
366373
def get_code(target_class) -> Set[str]:

test/test_gcs_obj_metadata_client.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,47 @@ def run(self):
1616

1717

1818
class TestGCSObjectMetadataClient(unittest.TestCase):
19-
def test_get_patched_obj_metadata(self):
20-
task_params: dict[Any, str] = {
19+
def setUp(self):
20+
self.task_params: dict[str, str] = {
2121
'param1': 'a' * 1000,
2222
'param2': str(1000),
2323
'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}),
2424
'param4': str([1, 2, 3, 4, 5]),
2525
'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)),
2626
'param6': '',
2727
}
28-
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=task_params)
28+
self.custom_labels: dict[str, Any] = {
29+
'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5),
30+
'created_by': 'hoge fuga',
31+
'empty': True,
32+
'try_num': 3,
33+
}
34+
35+
self.task_params_with_conflicts = {
36+
'empty': 'False',
37+
'created_by': 'fuga hoge',
38+
'param1': 'a' * 10,
39+
}
40+
41+
def test_normalize_labels_not_empty(self):
42+
got = GCSObjectMetadataClient._normalize_labels(None)
43+
self.assertEqual(got, {})
44+
45+
def test_normalize_labels_has_value(self):
46+
got = GCSObjectMetadataClient._normalize_labels(self.task_params)
47+
48+
self.assertIsInstance(got, dict)
49+
self.assertIsInstance(got, dict)
50+
self.assertIn('param1', got)
51+
self.assertIn('param2', got)
52+
self.assertIn('param3', got)
53+
self.assertIn('param4', got)
54+
self.assertIn('param5', got)
55+
self.assertIn('param6', got)
56+
57+
def test_get_patched_obj_metadata_only_task_params(self):
58+
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None)
59+
2960
self.assertIsInstance(got, dict)
3061
self.assertIn('param1', got)
3162
self.assertIn('param2', got)
@@ -34,17 +65,52 @@ def test_get_patched_obj_metadata(self):
3465
self.assertIn('param5', got)
3566
self.assertNotIn('param6', got)
3667

68+
def test_get_patched_obj_metadata_only_custom_labels(self):
69+
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels)
70+
71+
self.assertIsInstance(got, dict)
72+
self.assertIn('created_at', got)
73+
self.assertIn('created_by', got)
74+
self.assertIn('empty', got)
75+
self.assertIn('try_num', got)
76+
77+
def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self):
78+
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels)
79+
80+
self.assertIsInstance(got, dict)
81+
self.assertIn('param1', got)
82+
self.assertIn('param2', got)
83+
self.assertIn('param3', got)
84+
self.assertIn('param4', got)
85+
self.assertIn('param5', got)
86+
self.assertNotIn('param6', got)
87+
self.assertIn('created_at', got)
88+
self.assertIn('created_by', got)
89+
self.assertIn('empty', got)
90+
self.assertIn('try_num', got)
91+
3792
def test_get_patched_obj_metadata_with_exceeded_size_metadata(self):
38-
task_params = {
93+
size_exceeded_task_params = {
3994
'param1': 'a' * 5000,
4095
'param2': 'b' * 5000,
4196
}
4297
want = {
4398
'param1': 'a' * 5000,
4499
}
45-
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=task_params)
100+
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params)
46101
self.assertEqual(got, want)
47102

103+
def test_get_patched_obj_metadata_with_conflicts(self):
104+
got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels)
105+
self.assertIsInstance(got, dict)
106+
self.assertIn('created_at', got)
107+
self.assertIn('created_by', got)
108+
self.assertIn('empty', got)
109+
self.assertIn('try_num', got)
110+
self.assertEqual(got['empty'], 'True')
111+
self.assertEqual(got['created_by'], 'hoge fuga')
112+
self.assertEqual(got['param1'], 'a' * 10)
113+
48114

49115
class TestGokartTask(unittest.TestCase):
50116
@patch.object(_DummyTaskOnKart, '_get_output_target')
@@ -54,8 +120,7 @@ def test_mock_target_on_kart(self, mock_get_output_target):
54120

55121
task = _DummyTaskOnKart()
56122
task.dump({'key': 'value'}, mock_target)
57-
58-
mock_target.dump.assert_called_once_with({'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={})
123+
mock_target.dump.assert_called_once_with({'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None)
59124

60125

61126
if __name__ == '__main__':

test/test_task_on_kart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ def test_should_fail_lock_run_when_port_unset(self):
613613

614614

615615
class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
616-
def dump(self, _obj: Any, _target: Any = None):
616+
def dump(self, _obj: Any, _target: Any = None, _custom_labels: Any = None):
617617
# overrive dump() to do nothing.
618618
pass
619619

@@ -625,7 +625,7 @@ def complete(self):
625625

626626

627627
class _DummyTaskWithCompleted(gokart.TaskOnKart):
628-
def dump(self, obj: Any, _target: Any = None):
628+
def dump(self, obj: Any, _target: Any = None, custom_labels: Any = None):
629629
# overrive dump() to do nothing.
630630
pass
631631

0 commit comments

Comments
 (0)