Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
79a2881
WIP: End to implement the logic to gather the required task output path.
Mar 4, 2025
0cfe7ee
WIP: success to add output path in nest mode, but some other case sho…
Mar 4, 2025
3eee422
WIP: no ci apply.
Mar 5, 2025
ec3bf4f
feat: fix to pass labels and has_seen_keys.
Mar 5, 2025
22a69d0
feat: fix conflicts
Mar 5, 2025
08e3f59
CI: apply ruff and mypy
Mar 5, 2025
9b19a1c
feat: add implementation of nest mode.
Mar 5, 2025
accbf1d
feat: deal with kitagry comments.
Mar 6, 2025
6719f4d
feat: Remove CLI dependencies.
Mar 6, 2025
0bcc16c
feat: remove redundant statements.
Mar 6, 2025
5c41035
feat: change serialization expression for single FlattenableItems[Req…
Mar 6, 2025
0b951ab
CI: fix test and apply CI.
Mar 6, 2025
10795a2
feat: fix mypy error.
Mar 6, 2025
32b4343
feat: refactoring make _list_flatten inner function.
Mar 6, 2025
6f70a41
feat: fix nits miss and add __ prefix to avoid conflicts.
Mar 6, 2025
637f5da
feat: rename _list_flatten
Mar 7, 2025
b607926
Merge: fix conflicts.
tyzerrr Mar 15, 2025
a8059a1
Merge: fix conflicts.
tyzerrr Mar 15, 2025
27b1abd
feat: convert map object to list, any iterable objects that would be …
tyzerrr Apr 17, 2025
5ac1c4d
Merge remote-tracking branch 'origin/master' into feat/nestmode
tyzerrr Apr 17, 2025
f4479da
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
tyzerrr Apr 17, 2025
e71833b
feat: add new line to end of param.ini
tyzerrr Apr 22, 2025
46aabcf
feat: remove redundant expressions
tyzerrr Apr 22, 2025
7bde3b0
Merge branch 'master' into feat/nestmode
hirosassa Apr 24, 2025
4c44cea
feat: use yield to make memory efficient and use functools.reduce to …
tyzerrr Apr 28, 2025
dd6a629
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
tyzerrr Apr 28, 2025
d884c79
Merge branch 'master' into feat/nestmode
hirosassa Apr 28, 2025
f1418f8
feat: fix type of normalized_labels_list
tyzerrr Apr 28, 2025
6a1c4c2
Merge remote-tracking branch 'origin/feat/nestmode' into feat/nestmode
tyzerrr Apr 28, 2025
0b06455
chore: change custom_labels type
kitagry Apr 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import json
import re
from logging import getLogger
from typing import Any, Union
Expand All @@ -21,7 +22,7 @@ class GCSObjectMetadataClient:

@staticmethod
def _is_log_related_path(path: str) -> bool:
return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None

# This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
@staticmethod
Expand All @@ -32,7 +33,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
return netloc, path_without_initial_slash

@staticmethod
def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def add_task_state_labels(
path: str,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
required_task_outputs: dict[str, str] | None = None,
) -> None:
if GCSObjectMetadataClient._is_log_related_path(path):
return
# In gokart/object_storage.get_time_stamp, could find same call.
Expand All @@ -42,20 +48,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
if _response is None:
logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
return

response: dict[str, Any] = dict(_response)
original_metadata: dict[Any, Any] = {}
if 'metadata' in response.keys():
_metadata = response.get('metadata')
if _metadata is not None:
original_metadata = dict(_metadata)

patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
copy.deepcopy(original_metadata),
task_params,
custom_labels,
required_task_outputs if required_task_outputs else None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to be redundant

Suggested change
required_task_outputs if required_task_outputs else None,
required_task_outputs,

)

if original_metadata != patched_metadata:
# If we use update api, existing object metadata are removed, so should use patch api.
# See the official document descriptions.
Expand All @@ -71,7 +75,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
)
.execute()
)

if update_response is None:
logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')

Expand All @@ -84,13 +87,13 @@ def _get_patched_obj_metadata(
metadata: Any,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
required_task_outputs: dict[str, str] | None = None,
) -> Union[dict, Any]:
# If the metadata in the response from getting bucket and object information is not a dictionary,
# something may have gone wrong, so return the original metadata unpatched.
if not isinstance(metadata, dict):
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
return metadata

if not task_params and not custom_labels:
return metadata
# Maximum size of metadata for each object is 8 KiB.
Expand All @@ -101,23 +104,28 @@ def _get_patched_obj_metadata(
# However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
# Instead, users are expected to search using the labels they provided.
# Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
normalized_labels = (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[imo]
I prefer this because of readability

Suggested change
normalized_labels = (
normalized_labels = [normalized_custom_labels, normalized_task_params_labels]
if not required_task_outputs
normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))})

[normalized_custom_labels, normalized_task_params_labels]
if not required_task_outputs
else [normalized_custom_labels, normalized_custom_labels, {'required_task_outputs': json.dumps(required_task_outputs)}]
)
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)

@staticmethod
def _merge_custom_labels_and_task_params_labels(
    normalized_labels_list: list[dict[str, str]],
) -> dict[str, str]:
    """Merge a sequence of normalized label dicts into one metadata dict.

    Earlier dicts in ``normalized_labels_list`` take precedence: once a
    label name has been adopted, later occurrences are skipped. Callers
    pass user-provided custom labels first so they win over labels
    generated from task parameters.

    Labels with empty values are dropped (they carry no searchable
    content); both skip cases are logged for traceability.
    """
    merged_labels: dict[str, str] = {}
    # Plain iteration is sufficient: the input list is never mutated here,
    # so no defensive copy (normalized_labels_list[:]) is needed.
    # Values are always str because callers normalize them beforehand.
    for normalized_label in normalized_labels_list:
        for label_name, label_value in normalized_label.items():
            if len(label_value) == 0:
                logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
                continue
            if label_name in merged_labels:
                logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
                continue
            merged_labels[label_name] = label_value
    return merged_labels

# Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
Expand All @@ -132,10 +140,8 @@ def _get_label_size(label_name: str, label_value: str) -> int:
8 * 1024,
sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
)

if current_total_metadata_size <= max_gcs_metadata_size:
return labels

for label_name, label_value in reversed(labels.items()):
size = _get_label_size(label_name, label_value)
del labels[label_name]
Expand Down
8 changes: 7 additions & 1 deletion gokart/in_memory/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ def _get_task_lock_params(self) -> TaskLockParams:
def _load(self) -> Any:
return _repository.get_value(self._data_key)

def _dump(
    self,
    obj: Any,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: dict[str, str] | None = None,
) -> None:
    """Store *obj* in the in-memory repository under this target's data key.

    The label/metadata parameters exist only to satisfy the shared
    TargetOnKart._dump interface; an in-memory repository has no object
    metadata, so they are intentionally ignored here.
    """
    return _repository.set_value(self._data_key, obj)

def _remove(self) -> None:
Expand Down
48 changes: 40 additions & 8 deletions gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,23 @@ def exists(self) -> bool:
def load(self) -> Any:
return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()

def dump(
    self,
    obj,
    lock_at_dump: bool = True,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: dict[str, str] | None = None,
) -> None:
    """Persist *obj* to this target, optionally guarded by a task lock.

    task_params / custom_labels / required_task_outputs are forwarded to
    ``_dump`` unchanged; they describe metadata about the write (e.g. GCS
    object labels) and do not affect what is written.
    """
    if lock_at_dump:
        # Wrap _dump with the task lock so concurrent runs cannot write
        # the same target simultaneously; exist_check short-circuits if
        # the output already exists.
        wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
            obj=obj,
            task_params=task_params,
            custom_labels=custom_labels,
            required_task_outputs=required_task_outputs,
        )
    else:
        self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)

def remove(self) -> None:
if self.exists():
Expand All @@ -61,7 +71,13 @@ def _load(self) -> Any:
pass

@abstractmethod
def _dump(
    self,
    obj,
    # Use the `X | None` spelling consistently with the sibling
    # signatures instead of mixing in Optional[...].
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: dict[str, str] | None = None,
) -> None:
    """Write *obj* to the backing store.

    Implementations may record the label parameters as object metadata
    (as the GCS-backed target does) or ignore them entirely.
    """
    pass

@abstractmethod
Expand Down Expand Up @@ -98,11 +114,19 @@ def _load(self) -> Any:
with self._target.open('r') as f:
return self._processor.load(f)

def _dump(
    self,
    obj,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: dict[str, str] | None = None,
) -> None:
    """Serialize *obj* through the processor into the underlying target.

    When the target lives on GCS (path starts with 'gs://'), the task
    state labels are additionally patched onto the object's metadata.
    """
    with self._target.open('w') as f:
        self._processor.dump(obj, f)
    if self.path().startswith('gs://'):
        GCSObjectMetadataClient.add_task_state_labels(
            path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
        )

def _remove(self) -> None:
self._target.remove()
Expand Down Expand Up @@ -142,10 +166,18 @@ def _load(self) -> Any:
self._remove_temporary_directory()
return model

def _dump(
    self,
    obj,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: dict[str, str] | None = None,
) -> None:
    """Save the model *obj* plus its load function, then archive both.

    The load function is dumped via make_target so the same metadata
    (task_params, custom_labels, required_task_outputs) is propagated to
    that inner target as well.
    """
    self._make_temporary_directory()
    self._save_function(obj, self._model_path())
    make_target(self._load_function_path()).dump(
        self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
    )
    self._zip_client.make_archive()
    self._remove_temporary_directory()

Expand Down
18 changes: 18 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
import random
import sys
import types
from dataclasses import dataclass
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict, Generator, Generic, Iterable, List, Optional, Set, TypeVar, Union, overload

from gokart.utils import map_flattenable_items

if sys.version_info < (3, 13):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this part is not needed?

from typing_extensions import deprecated
else:
Expand Down Expand Up @@ -362,11 +365,26 @@ def dump(self, obj: Any, target: Union[None, str, TargetOnKart] = None, custom_l
if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame):
assert not obj.empty

@dataclass
class _RequiredTaskOutput:
task_name: str
output_path: str

_required_task_outputs = flatten(
map_flattenable_items(
lambda task: map_flattenable_items(
lambda output: _RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()
Comment thread
kitagry marked this conversation as resolved.
Outdated
),
self.requires(),
)
)
required_task_outputs = {r.task_name: r.output_path for r in _required_task_outputs}
self._get_output_target(target).dump(
obj,
lock_at_dump=self._lock_at_dump,
task_params=super().to_str_params(only_significant=True, only_public=True),
custom_labels=custom_labels,
required_task_outputs=required_task_outputs,
)

@staticmethod
Expand Down
17 changes: 16 additions & 1 deletion gokart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from io import BytesIO
from typing import Any, Iterable, Protocol, TypeVar, Union
from typing import Any, Callable, Iterable, Protocol, TypeVar, Union

import dill
import luigi
Expand Down Expand Up @@ -71,6 +71,21 @@ def flatten(targets: FlattenableItems[T]) -> list[T]:
return flat


K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Recursively apply *func* to every leaf of a flattenable structure.

    Container shape is preserved: dicts map values (keys untouched),
    tuples stay tuples, other iterables become lists. Strings are treated
    as leaves (not iterated character by character), as is any
    non-iterable value.
    """
    if isinstance(items, dict):
        return {k: map_flattenable_items(func, v) for k, v in items.items()}
    if isinstance(items, tuple):
        return tuple(map_flattenable_items(func, i) for i in items)
    # str is Iterable, so it must be checked before the generic Iterable branch.
    if isinstance(items, str):
        return func(items)  # type: ignore
    # Materialize as a list (not a lazy map object) so the result can be
    # iterated multiple times by downstream consumers such as flatten().
    if isinstance(items, Iterable):
        return [map_flattenable_items(func, i) for i in items]
    return func(items)


def load_dill_with_pandas_backward_compatibility(file: Union[FileLike, BytesIO]) -> Any:
"""Load binary dumped by dill with pandas backward compatibility.
pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle.
Expand Down
5 changes: 4 additions & 1 deletion test/test_gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def test_mock_target_on_kart(self, mock_get_output_target):

task = _DummyTaskOnKart()
task.dump({'key': 'value'}, mock_target)
mock_target.dump.assert_called_once_with({'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None)

mock_target.dump.assert_called_once_with(
{'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None, required_task_outputs={}
)


if __name__ == '__main__':
Expand Down
18 changes: 17 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from gokart.utils import flatten
from gokart.utils import flatten, map_flattenable_items


class TestFlatten(unittest.TestCase):
Expand All @@ -18,3 +18,19 @@ def test_flatten_int(self):

def test_flatten_none(self):
self.assertEqual(flatten(None), [])


class TestMapFlatten(unittest.TestCase):
    """Verifies map_flattenable_items applies the function to every leaf
    while preserving container structure (dict / tuple / list nesting)."""

    def test_map_flattenable_items(self):
        stringify = lambda x: str(x)

        # Flat dict: values are mapped, keys are untouched.
        self.assertEqual(map_flattenable_items(stringify, {'a': 1, 'b': 2}), {'a': '1', 'b': '2'})

        # Deeply nested tuples with an embedded dict keep their shape.
        nested_input = (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))
        nested_expected = ('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')})))
        self.assertEqual(map_flattenable_items(stringify, nested_input), nested_expected)

        # Mixed containers; leaves that are already str pass through func too.
        mixed_input = {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}}
        mixed_expected = {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}
        self.assertEqual(map_flattenable_items(stringify, mixed_input), mixed_expected)