Skip to content

Commit 3ae423b

Browse files
committed
address comments
1 parent f911f5f commit 3ae423b

File tree

4 files changed

+95
-32
lines changed

4 files changed

+95
-32
lines changed
Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,37 @@
11
class FieldMask(object):
22
"""Class for FieldMask message type."""
3+
# This is based on the base implementation from protobuf.
4+
# https://pigweed.googlesource.com/third_party/github/protocolbuffers/protobuf/+/HEAD/python/google/protobuf/internal/field_mask.py
5+
# The original implementation only works with proto generated classes.
6+
# Since our classes are not generated from proto files, we need to implement it manually.
37

4-
def ToJsonString(self):
8+
def __init__(self):
9+
"""Initialize FieldMask with empty paths."""
10+
self.paths = []
11+
12+
def ToJsonString(self) -> str:
513
"""Converts FieldMask to string."""
614
return ",".join(self.paths)
715

8-
def FromJsonString(self, value):
16+
def FromJsonString(self, value: str) -> None:
917
"""Converts string to FieldMask."""
1018
if not isinstance(value, str):
1119
raise ValueError("FieldMask JSON value not a string: {!r}".format(value))
1220
if value:
1321
self.paths = value.split(",")
1422
else:
1523
self.paths = []
24+
25+
def __eq__(self, other) -> bool:
26+
"""Check equality based on paths."""
27+
if not isinstance(other, FieldMask):
28+
return False
29+
return self.paths == other.paths
30+
31+
def __hash__(self) -> int:
32+
"""Hash based on paths tuple."""
33+
return hash(tuple(self.paths))
34+
35+
def __repr__(self) -> str:
36+
"""String representation for debugging."""
37+
return f"FieldMask(paths={self.paths})"

databricks/sdk/service/_internal.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def _escape_multi_segment_path_parameter(param: str) -> str:
5252

5353

5454
def _timestamp(d: Dict[str, any], field: str) -> Optional[Timestamp]:
55+
"""
56+
Helper function to convert a timestamp string to a Timestamp object.
57+
It takes a dictionary and a field name, and returns a Timestamp object.
58+
The field name is the key in the dictionary that contains the timestamp string.
59+
"""
5560
if field not in d or not d[field]:
5661
return None
5762
ts = Timestamp()
@@ -60,6 +65,11 @@ def _timestamp(d: Dict[str, any], field: str) -> Optional[Timestamp]:
6065

6166

6267
def _repeated_timestamp(d: Dict[str, any], field: str) -> Optional[List[Timestamp]]:
68+
"""
69+
Helper function to convert a list of timestamp strings to a list of Timestamp objects.
70+
It takes a dictionary and a field name, and returns a list of Timestamp objects.
71+
The field name is the key in the dictionary that contains the list of timestamp strings.
72+
"""
6373
if field not in d or not d[field]:
6474
return None
6575
result = []
@@ -71,6 +81,11 @@ def _repeated_timestamp(d: Dict[str, any], field: str) -> Optional[List[Timestam
7181

7282

7383
def _duration(d: Dict[str, any], field: str) -> Optional[Duration]:
84+
"""
85+
Helper function to convert a duration string to a Duration object.
86+
It takes a dictionary and a field name, and returns a Duration object.
87+
The field name is the key in the dictionary that contains the duration string.
88+
"""
7489
if field not in d or not d[field]:
7590
return None
7691
dur = Duration()
@@ -79,6 +94,11 @@ def _duration(d: Dict[str, any], field: str) -> Optional[Duration]:
7994

8095

8196
def _repeated_duration(d: Dict[str, any], field: str) -> Optional[List[Duration]]:
97+
"""
98+
Helper function to convert a list of duration strings to a list of Duration objects.
99+
It takes a dictionary and a field name, and returns a list of Duration objects.
100+
The field name is the key in the dictionary that contains the list of duration strings.
101+
"""
82102
if field not in d or not d[field]:
83103
return None
84104
result = []
@@ -90,6 +110,11 @@ def _repeated_duration(d: Dict[str, any], field: str) -> Optional[List[Duration]
90110

91111

92112
def _fieldmask(d: Dict[str, any], field: str) -> Optional[FieldMask]:
113+
"""
114+
Helper function to convert a fieldmask string to a FieldMask object.
115+
It takes a dictionary and a field name, and returns a FieldMask object.
116+
The field name is the key in the dictionary that contains the fieldmask string.
117+
"""
93118
if field not in d or not d[field]:
94119
return None
95120
fm = FieldMask()
@@ -98,6 +123,11 @@ def _fieldmask(d: Dict[str, any], field: str) -> Optional[FieldMask]:
98123

99124

100125
def _repeated_fieldmask(d: Dict[str, any], field: str) -> Optional[List[FieldMask]]:
126+
"""
127+
Helper function to convert a list of fieldmask strings to a list of FieldMask objects.
128+
It takes a dictionary and a field name, and returns a list of FieldMask objects.
129+
The field name is the key in the dictionary that contains the list of fieldmask strings.
130+
"""
101131
if field not in d or not d[field]:
102132
return None
103133
result = []

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ classifiers = [
2727
dependencies = [
2828
"requests>=2.28.1,<3",
2929
"google-auth~=2.0",
30+
"protobuf>=4.21.0,<7.0",
3031
]
3132

3233
[project.urls]

tests/test_internal.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -77,35 +77,38 @@ def test_timestamp(input_dict, field_name, expected_result, expected_json, descr
7777
assert result is None
7878
else:
7979
assert isinstance(result, Timestamp)
80-
assert result.ToJsonString() == expected_json
80+
ts = Timestamp()
81+
ts.FromJsonString(expected_json)
82+
assert result == ts
8183

8284

8385
@pytest.mark.parametrize(
84-
"input_dict,field_name,expected_length,expected_json_list,description",
86+
"input_dict,field_name,expected_json_list,description",
8587
[
8688
(
8789
{"field": ["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"]},
8890
"field",
89-
2,
9091
["2023-01-01T12:00:00Z", "2023-01-02T12:00:00Z"],
9192
"valid repeated timestamps",
9293
),
93-
({}, "field", None, None, "missing field"),
94-
({"field": None}, "field", None, None, "None value"),
95-
({"field": []}, "field", None, None, "empty list"),
94+
({}, "field", [], "missing field"),
95+
({"field": None}, "field", None, "None value"),
96+
({"field": []}, "field", [], "empty list"),
9697
],
9798
)
98-
def test_repeated_timestamp(input_dict, field_name, expected_length, expected_json_list, description):
99+
def test_repeated_timestamp(input_dict, field_name, expected_json_list, description):
99100
"""Test _repeated_timestamp function with various input scenarios."""
100101
result = _repeated_timestamp(input_dict, field_name)
101102

102-
if expected_length is None:
103+
if expected_json_list is None or len(expected_json_list) == 0:
103104
assert result is None
104105
else:
105-
assert len(result) == expected_length
106+
assert len(result) == len(expected_json_list)
106107
assert all(isinstance(ts, Timestamp) for ts in result)
107108
for i, expected_json in enumerate(expected_json_list):
108-
assert result[i].ToJsonString() == expected_json
109+
ts = Timestamp()
110+
ts.FromJsonString(expected_json)
111+
assert result[i] == ts
109112

110113

111114
@pytest.mark.parametrize(
@@ -125,29 +128,33 @@ def test_duration(input_dict, field_name, expected_result, expected_json, descri
125128
assert result is None
126129
else:
127130
assert isinstance(result, Duration)
128-
assert result.ToJsonString() == expected_json
131+
dur = Duration()
132+
dur.FromJsonString(expected_json)
133+
assert result == dur
129134

130135

131136
@pytest.mark.parametrize(
132-
"input_dict,field_name,expected_length,expected_json_list,description",
137+
"input_dict,field_name,expected_json_list,description",
133138
[
134-
({"field": ["3600s", "7200s"]}, "field", 2, ["3600s", "7200s"], "valid repeated durations"),
135-
({}, "field", None, None, "missing field"),
136-
({"field": None}, "field", None, None, "None value"),
137-
({"field": []}, "field", None, None, "empty list"),
139+
({"field": ["3600s", "7200s"]}, "field", ["3600s", "7200s"], "valid repeated durations"),
140+
({}, "field", [], "missing field"),
141+
({"field": None}, "field", None, "None value"),
142+
({"field": []}, "field", [], "empty list"),
138143
],
139144
)
140-
def test_repeated_duration(input_dict, field_name, expected_length, expected_json_list, description):
145+
def test_repeated_duration(input_dict, field_name, expected_json_list, description):
141146
"""Test _repeated_duration function with various input scenarios."""
142147
result = _repeated_duration(input_dict, field_name)
143148

144-
if expected_length is None:
149+
if expected_json_list is None or len(expected_json_list) == 0:
145150
assert result is None
146151
else:
147-
assert len(result) == expected_length
152+
assert len(result) == len(expected_json_list)
148153
assert all(isinstance(dur, Duration) for dur in result)
149154
for i, expected_json in enumerate(expected_json_list):
150-
assert result[i].ToJsonString() == expected_json
155+
dur = Duration()
156+
dur.FromJsonString(expected_json)
157+
assert result[i] == dur
151158

152159

153160
@pytest.mark.parametrize(
@@ -167,32 +174,35 @@ def test_fieldmask(input_dict, field_name, expected_result, expected_json, descr
167174
assert result is None
168175
else:
169176
assert isinstance(result, FieldMask)
170-
assert result.ToJsonString() == expected_json
177+
fm = FieldMask()
178+
fm.FromJsonString(expected_json)
179+
assert result == fm
171180

172181

173182
@pytest.mark.parametrize(
174-
"input_dict,field_name,expected_length,expected_json_list,description",
183+
"input_dict,field_name,expected_json_list,description",
175184
[
176185
(
177186
{"field": ["path1,path2", "path3,path4"]},
178187
"field",
179-
2,
180188
["path1,path2", "path3,path4"],
181189
"valid repeated fieldmasks",
182190
),
183-
({}, "field", None, None, "missing field"),
184-
({"field": None}, "field", None, None, "None value"),
185-
({"field": []}, "field", None, None, "empty list"),
191+
({}, "field", [], "missing field"),
192+
({"field": None}, "field", None, "None value"),
193+
({"field": []}, "field", [], "empty list"),
186194
],
187195
)
188-
def test_repeated_fieldmask(input_dict, field_name, expected_length, expected_json_list, description):
196+
def test_repeated_fieldmask(input_dict, field_name, expected_json_list, description):
189197
"""Test _repeated_fieldmask function with various input scenarios."""
190198
result = _repeated_fieldmask(input_dict, field_name)
191199

192-
if expected_length is None:
200+
if expected_json_list is None or len(expected_json_list) == 0:
193201
assert result is None
194202
else:
195-
assert len(result) == expected_length
203+
assert len(result) == len(expected_json_list)
196204
assert all(isinstance(fm, FieldMask) for fm in result)
197205
for i, expected_json in enumerate(expected_json_list):
198-
assert result[i].ToJsonString() == expected_json
206+
fm = FieldMask()
207+
fm.FromJsonString(expected_json)
208+
assert result[i] == fm

0 commit comments

Comments
 (0)