-
Notifications
You must be signed in to change notification settings - Fork 62
Support functionalities to enhance task traceability with metadata for dependency search. #450
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
79a2881
0cfe7ee
3eee422
ec3bf4f
22a69d0
08e3f59
9b19a1c
accbf1d
6719f4d
0bcc16c
5c41035
0b951ab
10795a2
32b4343
6f70a41
637f5da
b607926
a8059a1
27b1abd
5ac1c4d
f4479da
e71833b
46aabcf
7bde3b0
4c44cea
dd6a629
d884c79
f1418f8
6a1c4c2
0b06455
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,7 @@ | ||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| import copy | ||||||||||
| import json | ||||||||||
| import re | ||||||||||
| from logging import getLogger | ||||||||||
| from typing import Any, Union | ||||||||||
|
|
@@ -21,7 +22,7 @@ class GCSObjectMetadataClient: | |||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def _is_log_related_path(path: str) -> bool: | ||||||||||
| return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None | ||||||||||
| return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None | ||||||||||
|
|
||||||||||
| # This is the copied method of luigi.gcs._path_to_bucket_and_key(path). | ||||||||||
| @staticmethod | ||||||||||
|
|
@@ -32,7 +33,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]: | |||||||||
| return netloc, path_without_initial_slash | ||||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: | ||||||||||
| def add_task_state_labels( | ||||||||||
| path: str, | ||||||||||
| task_params: dict[str, str] | None = None, | ||||||||||
| custom_labels: dict[str, Any] | None = None, | ||||||||||
| required_task_outputs: dict[str, str] | None = None, | ||||||||||
| ) -> None: | ||||||||||
| if GCSObjectMetadataClient._is_log_related_path(path): | ||||||||||
| return | ||||||||||
| # In gokart/object_storage.get_time_stamp, could find same call. | ||||||||||
|
|
@@ -42,20 +48,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, | |||||||||
| if _response is None: | ||||||||||
| logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.') | ||||||||||
| return | ||||||||||
|
|
||||||||||
| response: dict[str, Any] = dict(_response) | ||||||||||
| original_metadata: dict[Any, Any] = {} | ||||||||||
| if 'metadata' in response.keys(): | ||||||||||
| _metadata = response.get('metadata') | ||||||||||
| if _metadata is not None: | ||||||||||
| original_metadata = dict(_metadata) | ||||||||||
|
|
||||||||||
| patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata( | ||||||||||
| copy.deepcopy(original_metadata), | ||||||||||
| task_params, | ||||||||||
| custom_labels, | ||||||||||
| required_task_outputs if required_task_outputs else None, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| if original_metadata != patched_metadata: | ||||||||||
| # If we use update api, existing object metadata are removed, so should use patch api. | ||||||||||
| # See the official document descriptions. | ||||||||||
|
|
@@ -71,7 +75,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, | |||||||||
| ) | ||||||||||
| .execute() | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| if update_response is None: | ||||||||||
| logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.') | ||||||||||
|
|
||||||||||
|
|
@@ -84,13 +87,13 @@ def _get_patched_obj_metadata( | |||||||||
| metadata: Any, | ||||||||||
| task_params: dict[str, str] | None = None, | ||||||||||
| custom_labels: dict[str, Any] | None = None, | ||||||||||
| required_task_outputs: dict[str, str] | None = None, | ||||||||||
| ) -> Union[dict, Any]: | ||||||||||
| # If metadata from response when getting bucket and object information is not dictionary, | ||||||||||
| # something wrong might be happened, so return original metadata, no patched. | ||||||||||
| if not isinstance(metadata, dict): | ||||||||||
| logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.') | ||||||||||
| return metadata | ||||||||||
|
|
||||||||||
| if not task_params and not custom_labels: | ||||||||||
| return metadata | ||||||||||
| # Maximum size of metadata for each object is 8 KiB. | ||||||||||
|
|
@@ -101,23 +104,28 @@ def _get_patched_obj_metadata( | |||||||||
| # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters. | ||||||||||
| # Instead, users are expected to search using the labels they provided. | ||||||||||
| # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence. | ||||||||||
| _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels) | ||||||||||
| normalized_labels = ( | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [imo]
Suggested change
|
||||||||||
| [normalized_custom_labels, normalized_task_params_labels] | ||||||||||
| if not required_task_outputs | ||||||||||
| else [normalized_custom_labels, normalized_custom_labels, {'required_task_outputs': json.dumps(required_task_outputs)}] | ||||||||||
| ) | ||||||||||
| _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels) | ||||||||||
| return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels) | ||||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def _merge_custom_labels_and_task_params_labels( | ||||||||||
| normalized_task_params: dict[str, str], | ||||||||||
| normalized_custom_labels: dict[str, Any], | ||||||||||
| normalized_labels_list: list[dict[str, Any]], | ||||||||||
| ) -> dict[str, str]: | ||||||||||
| merged_labels = copy.deepcopy(normalized_custom_labels) | ||||||||||
| for label_name, label_value in normalized_task_params.items(): | ||||||||||
| if len(label_value) == 0: | ||||||||||
| logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.') | ||||||||||
| continue | ||||||||||
| if label_name in merged_labels.keys(): | ||||||||||
| logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.') | ||||||||||
| continue | ||||||||||
| merged_labels[label_name] = label_value | ||||||||||
| merged_labels: dict[str, str] = {} | ||||||||||
| for normalized_label in normalized_labels_list[:]: | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [weak-IMO] I thought this part was a bit difficult to understand, since it is deeply nested. It may get better if you extract the inner merge into a helper. However, the current code is OK though. :)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you for the great suggestion! For this specific task of merging labels, the simple nested loop is likely more readable and Pythonic than using functools.reduce. While reduce can be used, in this scenario, the straightforward nested loop (or perhaps the alternative 'flattening' approach) probably offers better clarity and maintainability. What do you think?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I preferred the reduce approach, because it expresses the motivation of making the merged labels directly. In the nested loop, you need to read to L.147 to understand the motivation of building them. However, both approaches are OK, since this is a relatively small loop nest. :)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're right. |
||||||||||
| for label_name, label_value in normalized_label.items(): | ||||||||||
| if len(label_value) == 0: | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [MUST] This code may fail, since it seems to assume that the label value is a str. I prefer checking if it is a str, and then checking the length, as:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you for reviewing my code!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @TlexCypher
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @TlexCypher Could you check this comment? If you have confirmed that label_value is a str, you should
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I fixed it here: 0b06455 |
||||||||||
| logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.') | ||||||||||
| continue | ||||||||||
| if label_name in merged_labels.keys(): | ||||||||||
| logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.') | ||||||||||
| continue | ||||||||||
| merged_labels[label_name] = label_value | ||||||||||
| return merged_labels | ||||||||||
|
|
||||||||||
| # Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB. | ||||||||||
|
|
@@ -132,10 +140,8 @@ def _get_label_size(label_name: str, label_value: str) -> int: | |||||||||
| 8 * 1024, | ||||||||||
| sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| if current_total_metadata_size <= max_gcs_metadata_size: | ||||||||||
| return labels | ||||||||||
|
|
||||||||||
| for label_name, label_value in reversed(labels.items()): | ||||||||||
| size = _get_label_size(label_name, label_value) | ||||||||||
| del labels[label_name] | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,13 +30,23 @@ def exists(self) -> bool: | |
| def load(self) -> Any: | ||
| return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())() | ||
|
|
||
| def dump(self, obj, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: | ||
| def dump( | ||
| self, | ||
| obj, | ||
| lock_at_dump: bool = True, | ||
| task_params: dict[str, str] | None = None, | ||
| custom_labels: dict[str, Any] | None = None, | ||
| required_task_outputs: dict[str, str] | None = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [imo] |
||
| ) -> None: | ||
| if lock_at_dump: | ||
| wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)( | ||
| obj=obj, task_params=task_params, custom_labels=custom_labels | ||
| obj=obj, | ||
| task_params=task_params, | ||
| custom_labels=custom_labels, | ||
| required_task_outputs=required_task_outputs, | ||
| ) | ||
| else: | ||
| self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels) | ||
| self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs) | ||
|
|
||
| def remove(self) -> None: | ||
| if self.exists(): | ||
|
|
@@ -61,7 +71,13 @@ def _load(self) -> Any: | |
| pass | ||
|
|
||
| @abstractmethod | ||
| def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: | ||
| def _dump( | ||
| self, | ||
| obj, | ||
| task_params: Optional[dict[str, str]] = None, | ||
| custom_labels: dict[str, Any] | None = None, | ||
| required_task_outputs: dict[str, str] | None = None, | ||
| ) -> None: | ||
| pass | ||
|
|
||
| @abstractmethod | ||
|
|
@@ -98,11 +114,19 @@ def _load(self) -> Any: | |
| with self._target.open('r') as f: | ||
| return self._processor.load(f) | ||
|
|
||
| def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: | ||
| def _dump( | ||
| self, | ||
| obj, | ||
| task_params: dict[str, str] | None = None, | ||
| custom_labels: dict[str, Any] | None = None, | ||
| required_task_outputs: dict[str, str] | None = None, | ||
| ) -> None: | ||
| with self._target.open('w') as f: | ||
| self._processor.dump(obj, f) | ||
| if self.path().startswith('gs://'): | ||
| GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels) | ||
| GCSObjectMetadataClient.add_task_state_labels( | ||
| path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs | ||
| ) | ||
|
|
||
| def _remove(self) -> None: | ||
| self._target.remove() | ||
|
|
@@ -142,10 +166,18 @@ def _load(self) -> Any: | |
| self._remove_temporary_directory() | ||
| return model | ||
|
|
||
| def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: | ||
| def _dump( | ||
| self, | ||
| obj, | ||
| task_params: dict[str, str] | None = None, | ||
| custom_labels: dict[str, Any] | None = None, | ||
| required_task_outputs: dict[str, str] | None = None, | ||
| ) -> None: | ||
| self._make_temporary_directory() | ||
| self._save_function(obj, self._model_path()) | ||
| make_target(self._load_function_path()).dump(self._load_function, task_params=task_params) | ||
| make_target(self._load_function_path()).dump( | ||
| self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs | ||
| ) | ||
| self._zip_client.make_archive() | ||
| self._remove_temporary_directory() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,10 +7,13 @@ | |
| import random | ||
| import sys | ||
| import types | ||
| from dataclasses import dataclass | ||
| from importlib import import_module | ||
| from logging import getLogger | ||
| from typing import Any, Callable, Dict, Generator, Generic, Iterable, List, Optional, Set, TypeVar, Union, overload | ||
|
|
||
| from gokart.utils import map_flattenable_items | ||
|
|
||
| if sys.version_info < (3, 13): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe this part is not needed? |
||
| from typing_extensions import deprecated | ||
| else: | ||
|
|
@@ -362,11 +365,26 @@ def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_l | |
| if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame): | ||
| assert not obj.empty | ||
|
|
||
| @dataclass | ||
| class _RequiredTaskOutput: | ||
| task_name: str | ||
| output_path: str | ||
|
|
||
| _required_task_outputs = flatten( | ||
| map_flattenable_items( | ||
| lambda task: map_flattenable_items( | ||
| lambda output: _RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output() | ||
|
kitagry marked this conversation as resolved.
Outdated
|
||
| ), | ||
| self.requires(), | ||
| ) | ||
| ) | ||
| required_task_outputs = {r.task_name: r.output_path for r in _required_task_outputs} | ||
| self._get_output_target(target).dump( | ||
| obj, | ||
| lock_at_dump=self._lock_at_dump, | ||
| task_params=super().to_str_params(only_significant=True, only_public=True), | ||
| custom_labels=custom_labels, | ||
| required_task_outputs=required_task_outputs, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |||||
| import os | ||||||
| import sys | ||||||
| from io import BytesIO | ||||||
| from typing import Any, Iterable, Protocol, TypeVar, Union | ||||||
| from typing import Any, Callable, Iterable, Protocol, TypeVar, Union | ||||||
|
|
||||||
| import dill | ||||||
| import luigi | ||||||
|
|
@@ -71,6 +71,21 @@ def flatten(targets: FlattenableItems[T]) -> list[T]: | |||||
| return flat | ||||||
|
|
||||||
|
|
||||||
| K = TypeVar('K') | ||||||
|
|
||||||
|
|
||||||
| def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]: | ||||||
| if isinstance(items, dict): | ||||||
| return {k: map_flattenable_items(func, v) for k, v in items.items()} | ||||||
| if isinstance(items, tuple): | ||||||
| return tuple(map_flattenable_items(func, i) for i in items) | ||||||
| if isinstance(items, str): | ||||||
| return func(items) # type: ignore | ||||||
| if isinstance(items, Iterable): | ||||||
| return [map_flattenable_items(func, i) for i in items] | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| return func(items) | ||||||
|
|
||||||
|
|
||||||
| def load_dill_with_pandas_backward_compatibility(file: Union[FileLike, BytesIO]) -> Any: | ||||||
| """Load binary dumped by dill with pandas backward compatibility. | ||||||
| pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle. | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to be redundant