Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def _get_patched_obj_metadata(
if not isinstance(metadata, dict):
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
return metadata
if not task_params and not custom_labels:
return metadata
# Maximum size of metadata for each object is 8 KiB.
# [Link]: https://cloud.google.com/storage/quotas#objects
normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params)
Expand All @@ -117,18 +115,17 @@ def _get_patched_obj_metadata(

@staticmethod
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah—since str is also an Iterable, calling:

_iterable_flatten([["a"]]) will never terminate.

The following implementation of _iterable_flatten works correctly:

def _iterable_flatten(nested_list: Iterable) -> Iterable[str]:
    for item in nested_list:
        if isinstance(item, str):
            yield item
        elif isinstance(item, Iterable):
            yield from _iterable_flatten(item)
        else:
            yield item

However, since recursion is only needed for _get_serialized_string, you removed iterable_flatten entirely, right? @kitagry

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes👍

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kitagry Please include that background in the PR comment for future reference.

for item in nested_list:
if isinstance(item, Iterable):
yield from _iterable_flatten(item)
else:
yield item

if isinstance(required_task_outputs, dict):
if isinstance(required_task_outputs, RequiredTaskOutput):
return required_task_outputs.serialize()
elif isinstance(required_task_outputs, dict):
return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
if isinstance(required_task_outputs, Iterable):
return list(_iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]))
return [required_task_outputs.serialize()]
elif isinstance(required_task_outputs, Iterable):
return [GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]
else:
raise TypeError(
f'Unsupported type for required_task_outputs: {type(required_task_outputs)}. '
'It should be RequiredTaskOutput, dict, or iterable of RequiredTaskOutput.'
)

@staticmethod
def _merge_custom_labels_and_task_params_labels(
Expand Down
27 changes: 27 additions & 0 deletions test/test_gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import gokart
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart


Expand Down Expand Up @@ -113,6 +114,32 @@ def test_get_patched_obj_metadata_with_conflicts(self):
self.assertEqual(got['created_by'], 'hoge fuga')
self.assertEqual(got['param1'], 'a' * 10)

def test_get_patched_obj_metadata_with_required_task_outputs(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata(
{},
required_task_outputs=[
RequiredTaskOutput(task_name='task1', output_path='path/to/output1'),
],
)

self.assertIsInstance(got, dict)
self.assertIn('__required_task_outputs', got)
self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]')

def test_get_patched_obj_metadata_with_nested_required_task_outputs(self):
got = GCSObjectMetadataClient._get_patched_obj_metadata(
{},
required_task_outputs={
'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')},
},
)

self.assertIsInstance(got, dict)
self.assertIn('__required_task_outputs', got)
self.assertEqual(
got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}'
)


class TestGokartTask(unittest.TestCase):
@patch.object(_DummyTaskOnKart, '_get_output_target')
Expand Down