diff --git a/databricks/sdk/common/types/__init__.py b/databricks/sdk/common/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/databricks/sdk/common/types/fieldmask.py b/databricks/sdk/common/types/fieldmask.py new file mode 100644 index 000000000..739188080 --- /dev/null +++ b/databricks/sdk/common/types/fieldmask.py @@ -0,0 +1,39 @@ +class FieldMask(object): + """Class for FieldMask message type.""" + + # This is based on the base implementation from protobuf. + # https://pigweed.googlesource.com/third_party/github/protocolbuffers/protobuf/+/HEAD/python/google/protobuf/internal/field_mask.py + # The original implementation only works with proto generated classes. + # Since our classes are not generated from proto files, we need to implement it manually. + + def __init__(self, field_mask=None): + """Initializes the FieldMask.""" + if field_mask: + self.paths = field_mask + + def ToJsonString(self) -> str: + """Converts FieldMask to string.""" + return ",".join(self.paths) + + def FromJsonString(self, value: str) -> None: + """Converts string to FieldMask.""" + if not isinstance(value, str): + raise ValueError("FieldMask JSON value not a string: {!r}".format(value)) + if value: + self.paths = value.split(",") + else: + self.paths = [] + + def __eq__(self, other) -> bool: + """Check equality based on paths.""" + if not isinstance(other, FieldMask): + return False + return self.paths == other.paths + + def __hash__(self) -> int: + """Hash based on paths tuple.""" + return hash(tuple(self.paths)) + + def __repr__(self) -> str: + """String representation for debugging.""" + return f"FieldMask(paths={self.paths})" diff --git a/databricks/sdk/service/_internal.py b/databricks/sdk/service/_internal.py index 1e501e0e0..c48035cab 100644 --- a/databricks/sdk/service/_internal.py +++ b/databricks/sdk/service/_internal.py @@ -1,6 +1,11 @@ import datetime import urllib.parse -from typing import Callable, Dict, Generic, Optional, Type, TypeVar +from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar + +from google.protobuf.duration_pb2 import Duration +from google.protobuf.timestamp_pb2 import Timestamp + +from databricks.sdk.common.types.fieldmask import FieldMask def _from_dict(d: Dict[str, any], field: str, cls: Type) -> any: @@ -46,6 +51,93 @@ def _escape_multi_segment_path_parameter(param: str) -> str: return urllib.parse.quote(param) +def _timestamp(d: Dict[str, any], field: str) -> Optional[Timestamp]: + """ + Helper function to convert a timestamp string to a Timestamp object. + It takes a dictionary and a field name, and returns a Timestamp object. + The field name is the key in the dictionary that contains the timestamp string. + """ + if field not in d or not d[field]: + return None + ts = Timestamp() + ts.FromJsonString(d[field]) + return ts + + +def _repeated_timestamp(d: Dict[str, any], field: str) -> Optional[List[Timestamp]]: + """ + Helper function to convert a list of timestamp strings to a list of Timestamp objects. + It takes a dictionary and a field name, and returns a list of Timestamp objects. + The field name is the key in the dictionary that contains the list of timestamp strings. + """ + if field not in d or not d[field]: + return None + result = [] + for v in d[field]: + ts = Timestamp() + ts.FromJsonString(v) + result.append(ts) + return result + + +def _duration(d: Dict[str, any], field: str) -> Optional[Duration]: + """ + Helper function to convert a duration string to a Duration object. + It takes a dictionary and a field name, and returns a Duration object. + The field name is the key in the dictionary that contains the duration string. + """ + if field not in d or not d[field]: + return None + dur = Duration() + dur.FromJsonString(d[field]) + return dur + + +def _repeated_duration(d: Dict[str, any], field: str) -> Optional[List[Duration]]: + """ + Helper function to convert a list of duration strings to a list of Duration objects. + It takes a dictionary and a field name, and returns a list of Duration objects. + The field name is the key in the dictionary that contains the list of duration strings. + """ + if field not in d or not d[field]: + return None + result = [] + for v in d[field]: + dur = Duration() + dur.FromJsonString(v) + result.append(dur) + return result + + +def _fieldmask(d: Dict[str, any], field: str) -> Optional[FieldMask]: + """ + Helper function to convert a fieldmask string to a FieldMask object. + It takes a dictionary and a field name, and returns a FieldMask object. + The field name is the key in the dictionary that contains the fieldmask string. + """ + if field not in d or not d[field]: + return None + fm = FieldMask() + fm.FromJsonString(d[field]) + return fm + + +def _repeated_fieldmask(d: Dict[str, any], field: str) -> Optional[List[FieldMask]]: + """ + Helper function to convert a list of fieldmask strings to a list of FieldMask objects. + It takes a dictionary and a field name, and returns a list of FieldMask objects. + The field name is the key in the dictionary that contains the list of fieldmask strings. + """ + if field not in d or not d[field]: + return None + result = [] + for v in d[field]: + fm = FieldMask() + fm.FromJsonString(v) + result.append(fm) + return result + + ReturnType = TypeVar("ReturnType") diff --git a/pyproject.toml b/pyproject.toml index b1314929a..bf7ee1b57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ dependencies = [ "requests>=2.28.1,<3", "google-auth~=2.0", + "protobuf>=4.21.0,<7.0", ] [project.urls] diff --git a/tests/test_fieldmask.py b/tests/test_fieldmask.py new file mode 100644 index 000000000..3505bd90c --- /dev/null +++ b/tests/test_fieldmask.py @@ -0,0 +1,80 @@ +import pytest + +from databricks.sdk.common.types.fieldmask import FieldMask + + +@pytest.mark.parametrize( + "input_paths,expected_result,description", + [ + (["field1", "field2", "field3"], "field1,field2,field3", "basic list of paths"), + (["single_field"], "single_field", "single path"), + ([], "", "empty paths list"), + (["user.name", "user.email", "address.street"], "user.name,user.email,address.street", "nested field paths"), + ], +) +def test_to_json_string(input_paths, expected_result, description): + """Test ToJsonString with various path configurations.""" + field_mask = FieldMask() + field_mask.paths = input_paths + + result = field_mask.ToJsonString() + + assert result == expected_result + + +@pytest.mark.parametrize( + "input_string,expected_paths,description", + [ + ("field1,field2,field3", ["field1", "field2", "field3"], "basic comma-separated string"), + ("single_field", ["single_field"], "single field"), + ("", [], "empty string"), + ("user.name,user.email,address.street", ["user.name", "user.email", "address.street"], "nested field paths"), + ("field1, field2 , field3", ["field1", " field2 ", " field3"], "spaces around commas"), + ], +) +def test_from_json_string_success_cases(input_string, expected_paths, description): + """Test FromJsonString with various valid input strings.""" + field_mask = FieldMask() + + field_mask.FromJsonString(input_string) + + assert field_mask.paths == expected_paths + + +@pytest.mark.parametrize( + "invalid_input,expected_error_substring,description", + [ + (123, "FieldMask JSON value not a string: 123", "non-string integer input"), + (None, "FieldMask JSON value not a string: None", "None input"), + (["field1", "field2"], "FieldMask JSON value not a string:", "list input"), + ({"field": "value"}, "FieldMask JSON value not a string:", "dict input"), + ], +) +def test_from_json_string_error_cases(invalid_input, expected_error_substring, description): + """Test FromJsonString raises ValueError for invalid input types.""" + field_mask = FieldMask() + + with pytest.raises(ValueError) as exc_info: + field_mask.FromJsonString(invalid_input) + + assert expected_error_substring in str(exc_info.value) + + +@pytest.mark.parametrize( + "original_paths,description", + [ + (["user.name", "user.email", "profile.settings"], "multiple nested fields"), + (["single_field"], "single field"), + ([], "empty paths"), + ], +) +def test_roundtrip_conversion(original_paths, description): + """Test that ToJsonString and FromJsonString are inverse operations.""" + field_mask = FieldMask() + field_mask.paths = original_paths + + # Convert to string and back. + json_string = field_mask.ToJsonString() + field_mask.FromJsonString(json_string) + + assert field_mask.paths == original_paths diff --git a/tests/test_internal.py b/tests/test_internal.py index d432b5cb4..b0417ec64 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -1,9 +1,15 @@ from dataclasses import dataclass from enum import Enum +import pytest +from google.protobuf.duration_pb2 import Duration +from google.protobuf.timestamp_pb2 import Timestamp + +from databricks.sdk.common.types.fieldmask import FieldMask from databricks.sdk.service._internal import ( - _enum, _escape_multi_segment_path_parameter, _from_dict, _repeated_dict, - _repeated_enum) + _duration, _enum, _escape_multi_segment_path_parameter, _fieldmask, + _from_dict, _repeated_dict, _repeated_duration, _repeated_enum, + _repeated_fieldmask, _repeated_timestamp, _timestamp) class A(Enum): @@ -52,3 +58,154 @@ def test_escape_multi_segment_path_parameter(): assert _escape_multi_segment_path_parameter("a/b") == "a/b" assert _escape_multi_segment_path_parameter("a?b") == "a%3Fb" assert _escape_multi_segment_path_parameter("a#b") == "a%23b" + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_timestamp,description", + [ + ( + {"field": "2023-01-01T12:00:00Z"}, + "field", + Timestamp(seconds=1672574400), + "valid timestamp", + ), + ({}, "field", None, "missing field"), + ({"field": None}, "field", None, "None value"), + ({"field": ""}, "field", None, "empty value"), + ], +) +def test_timestamp(input_dict, field_name, expected_timestamp, description): + """Test _timestamp function with various input scenarios.""" + result = _timestamp(input_dict, field_name) + + if expected_timestamp is None: + assert result is None + else: + assert isinstance(result, Timestamp) + assert result == expected_timestamp + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_timestamp_list,description", + [ + ( + {"field": ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"]}, + "field", + [Timestamp(seconds=1672574400), Timestamp(seconds=1672660800)], + "valid repeated timestamps", + ), + ({}, "field", [], "missing field"), + ({"field": None}, "field", [], "None value"), + ({"field": []}, "field", [], "empty list"), + ], +) +def test_repeated_timestamp(input_dict, field_name, expected_timestamp_list, description): + """Test _repeated_timestamp function with various input scenarios.""" + result = _repeated_timestamp(input_dict, field_name) + + if expected_timestamp_list is None or len(expected_timestamp_list) == 0: + assert result is None + else: + assert len(result) == len(expected_timestamp_list) + assert all(isinstance(ts, Timestamp) for ts in result) + for i, expected_timestamp in enumerate(expected_timestamp_list): + assert result[i] == expected_timestamp + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_duration,description", + [ + ({"field": "3600s"}, "field", Duration(seconds=3600), "valid duration"), + ({}, "field", None, "missing field"), + ({"field": None}, "field", None, "None value"), + ({"field": ""}, "field", None, "empty value"), + ], +) +def test_duration(input_dict, field_name, expected_duration, description): + """Test _duration function with various input scenarios.""" + result = _duration(input_dict, field_name) + + if expected_duration is None: + assert result is None + else: + assert isinstance(result, Duration) + assert result == expected_duration + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_duration_list,description", + [ + ( + {"field": ["3600s", "7200s"]}, + "field", + [Duration(seconds=3600), Duration(seconds=7200)], + "valid repeated durations", + ), + ({}, "field", [], "missing field"), + ({"field": None}, "field", None, "None value"), + ({"field": []}, "field", [], "empty list"), + ], +) +def test_repeated_duration(input_dict, field_name, expected_duration_list, description): + """Test _repeated_duration function with various input scenarios.""" + result = _repeated_duration(input_dict, field_name) + + if expected_duration_list is None or len(expected_duration_list) == 0: + assert result is None + else: + assert len(result) == len(expected_duration_list) + assert all(isinstance(dur, Duration) for dur in result) + for i, expected_duration in enumerate(expected_duration_list): + assert result[i] == expected_duration + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_fieldmask,description", + [ + ( + {"field": "path1,path2"}, + "field", + FieldMask(field_mask=["path1", "path2"]), + "valid fieldmask", + ), + ({}, "field", None, "missing field"), + ({"field": None}, "field", None, "None value"), + ({"field": ""}, "field", None, "empty value"), + ], +) +def test_fieldmask(input_dict, field_name, expected_fieldmask, description): + """Test _fieldmask function with various input scenarios.""" + result = _fieldmask(input_dict, field_name) + + if expected_fieldmask is None: + assert result is None + else: + assert isinstance(result, FieldMask) + assert result == expected_fieldmask + + +@pytest.mark.parametrize( + "input_dict,field_name,expected_fieldmask_list,description", + [ + ( + {"field": ["path1,path2", "path3,path4"]}, + "field", + [FieldMask(field_mask=["path1", "path2"]), FieldMask(field_mask=["path3", "path4"])], + "valid repeated fieldmasks", + ), + ({}, "field", [], "missing field"), + ({"field": None}, "field", None, "None value"), + ({"field": []}, "field", [], "empty list"), + ], +) +def test_repeated_fieldmask(input_dict, field_name, expected_fieldmask_list, description): + """Test _repeated_fieldmask function with various input scenarios.""" + result = _repeated_fieldmask(input_dict, field_name) + + if expected_fieldmask_list is None or len(expected_fieldmask_list) == 0: + assert result is None + else: + assert len(result) == len(expected_fieldmask_list) + assert all(isinstance(fm, FieldMask) for fm in result) + for i, expected_fieldmask in enumerate(expected_fieldmask_list): + assert result[i] == expected_fieldmask