Skip to content

Commit 0d4cb70

Browse files
rchen152The dataclass_array Authors
authored andcommitted
Silence some pytype errors.
PiperOrigin-RevId: 510514175
1 parent 29a85ba commit 0d4cb70

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

dataclass_array/array_dataclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def replace(self: _DcT, **kwargs: Any) -> _DcT:
422422
}
423423

424424
# Create the new object
425-
new_self = dataclasses.replace(self, **init_kwargs)
425+
new_self = dataclasses.replace(self, **init_kwargs) # pytype: disable=wrong-arg-types # re-none
426426

427427
# TODO(epot): Could try to unify logic bellow with `tree_unflatten`
428428

@@ -677,7 +677,7 @@ def tree_flatten(self) -> tuple[list[DcOrArray], _TreeMetadata]:
677677
array_field_names=list(self._all_array_fields.keys()),
678678
non_array_field_kwargs={
679679
f.name: getattr(self, f.name)
680-
for f in dataclasses.fields(self)
680+
for f in dataclasses.fields(self) # pytype: disable=wrong-arg-types # re-none
681681
if f.name not in self._all_array_fields # pylint: disable=unsupported-membership-test
682682
},
683683
)
@@ -698,7 +698,7 @@ def tree_unflatten(
698698
)
699699
init_fields = {}
700700
non_init_fields = {}
701-
fields = {f.name: f for f in dataclasses.fields(cls)}
701+
fields = {f.name: f for f in dataclasses.fields(cls)} # pytype: disable=wrong-arg-types # re-none
702702
for k, v in metadata.non_array_field_kwargs.items():
703703
if fields[k].init:
704704
init_fields[k] = v

dataclass_array/utils/tree_utils_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def add_fn(x, y):
6666

6767
x0 = {'a': [1, 2]}
6868
x1 = {'a': [10, 20]}
69-
assert tree_utils.tree_map(add_fn, x0, x1) == {'a': [11, 22]}
69+
assert tree_utils.tree_map(add_fn, x0, x1) == {'a': [11, 22]} # pytype: disable=wrong-arg-types # re-none
7070

7171

7272
def test_tree_map_chex():
@@ -79,4 +79,4 @@ def add_fn(x, y):
7979

8080
x0 = {'a': A(x=1, y=2)}
8181
x1 = {'a': A(x=10, y=20)}
82-
assert tree_utils.tree_map(add_fn, x0, x1) == {'a': A(x=11, y=22)}
82+
assert tree_utils.tree_map(add_fn, x0, x1) == {'a': A(x=11, y=22)} # pytype: disable=wrong-arg-types # re-none

0 commit comments

Comments
 (0)