Commit 834115b

update apply_to_collections to support dataclass inputs (#11889)

armanal and rohitgr7 authored
Co-authored-by: rohitgr7 <[email protected]>
1 parent 97121a5 commit 834115b

File tree: 3 files changed, +172 -81 lines changed

CHANGELOG.md (3 additions & 0 deletions)

@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for Habana Accelerator (HPU) ([#11808](https://github.com/PyTorchLightning/pytorch-lightning/pull/11808))
 
+- Added support for dataclasses in `apply_to_collections` ([#11889](https://github.com/PyTorchLightning/pytorch-lightning/pull/11889))
+
+
 
 ### Changed
 
pytorch_lightning/utilities/apply_func.py (54 additions & 1 deletion)

@@ -138,6 +138,7 @@ def apply_to_collection(
         result = deepcopy(data, memo=memo)
         # apply function to each field
         for field_name, (field_value, field_init) in fields.items():
+            v = None
             if field_init:
                 v = apply_to_collection(
                     field_value,
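
Note on the one-line fix above: fields declared with `init=False` skip the `if field_init:` branch, so `v` could be referenced before assignment further down; the added default keeps it bound. A minimal sketch of an input that exercises this path, assuming a pytorch_lightning build that includes this commit (`Holder` is a hypothetical dataclass, not part of the codebase):

import dataclasses

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


@dataclasses.dataclass
class Holder:
    values: torch.Tensor
    # An init=False field skips the `if field_init:` branch above; the new
    # `v = None` default keeps `v` defined for the code that follows.
    counter: int = dataclasses.field(init=False, default=0)


out = apply_to_collection(Holder(torch.tensor([1.0, 2.0])), torch.Tensor, lambda t: t * 2)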
@@ -215,13 +216,65 @@ def apply_to_collections(
     is_namedtuple = _is_namedtuple(data1)
     is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
     if (is_namedtuple or is_sequence) and data2 is not None:
-        assert len(data1) == len(data2), "Sequence collections have different sizes"
+        assert len(data1) == len(data2), "Sequence collections have different sizes."
         out = [
             apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
             for v1, v2 in zip(data1, data2)
         ]
         return elem_type(*out) if is_namedtuple else elem_type(out)
 
+    if _is_dataclass_instance(data1) and data2 is not None:
+        if not _is_dataclass_instance(data2):
+            raise TypeError(
+                "Expected inputs to be dataclasses of the same type or to have identical fields"
+                f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
+            )
+        if not (
+            len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
+            and all(map(lambda f1, f2: isinstance(f1, type(f2)), dataclasses.fields(data1), dataclasses.fields(data2)))
+        ):
+            raise TypeError("Dataclasses fields do not match.")
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        data = [data1, data2]
+        fields: List[dict] = [{}, {}]
+        memo: dict = {}
+        for i in range(len(data)):
+            for field in dataclasses.fields(data[i]):
+                field_value = getattr(data[i], field.name)
+                fields[i][field.name] = (field_value, field.init)
+                if i == 0:
+                    memo[id(field_value)] = field_value
+
+        result = deepcopy(data1, memo=memo)
+
+        # apply function to each field
+        for ((field_name, (field_value1, field_init1)), (_, (field_value2, field_init2))) in zip(
+            fields[0].items(), fields[1].items()
+        ):
+            v = None
+            if field_init1 and field_init2:
+                v = apply_to_collections(
+                    field_value1,
+                    field_value2,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    **kwargs,
+                )
+            if not field_init1 or not field_init2 or v is None:  # retain old value
+                return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+            try:
+                setattr(result, field_name, v)
+            except dataclasses.FrozenInstanceError as e:
+                raise MisconfigurationException(
+                    "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
+                    " HINT: is your batch a frozen dataclass?"
+                ) from e
+        return result
+
     return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)

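For quick orientation, a minimal sketch of what the new branch enables, assuming a pytorch_lightning build that includes this commit (`Pair` is a hypothetical dataclass used only for illustration): two dataclass inputs of the same type are traversed field by field, and `function` receives one leaf from each input, mirroring the existing sequence/namedtuple behaviour.

import dataclasses

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collections


@dataclasses.dataclass
class Pair:
    left: torch.Tensor
    right: torch.Tensor


a = Pair(torch.tensor([1.0]), torch.tensor([2.0]))
b = Pair(torch.tensor([10.0]), torch.tensor([20.0]))

# Matching fields are zipped, so the function sees (a.left, b.left), then (a.right, b.right).
summed = apply_to_collections(a, b, torch.Tensor, lambda x, y: x + y)
assert torch.equal(summed.left, torch.tensor([11.0]))
assert torch.equal(summed.right, torch.tensor([22.0]))
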
tests/utilities/test_apply_func.py (115 additions & 80 deletions)

@@ -25,89 +25,94 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
+@dataclasses.dataclass
+class Feature:
+    input_ids: torch.Tensor
+    segment_ids: np.ndarray
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, Feature):
+            return NotImplemented
+
+        return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
+
+
+@dataclasses.dataclass
+class ModelExample:
+    example_ids: List[str]
+    feature: Feature
+    label: torch.Tensor
+    some_constant: int = dataclasses.field(init=False)
+
+    def __post_init__(self):
+        self.some_constant = 7
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, ModelExample):
+            return NotImplemented
+
+        return (
+            self.example_ids == o.example_ids
+            and self.feature == o.feature
+            and torch.equal(self.label, o.label)
+            and self.some_constant == o.some_constant
+        )
+
+
+@dataclasses.dataclass
+class WithClassVar:
+    class_var: ClassVar[int] = 0
+    dummy: Any
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithClassVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
+@dataclasses.dataclass
+class WithInitVar:
+    dummy: Any
+    override: InitVar[Optional[Any]] = None
+
+    def __post_init__(self, override: Optional[Any]):
+        if override is not None:
+            self.dummy = override
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithInitVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
+@dataclasses.dataclass
+class WithClassAndInitVar:
+    class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+    dummy: Any
+    override: InitVar[Optional[Any]] = torch.tensor(1)
+
+    def __post_init__(self, override: Optional[Any]):
+        if override is not None:
+            self.dummy = override
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithClassAndInitVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
 def test_recursive_application_to_collection():
     ntc = namedtuple("Foo", ["bar"])
 
-    @dataclasses.dataclass
-    class Feature:
-        input_ids: torch.Tensor
-        segment_ids: np.ndarray
-
-        def __eq__(self, o: object) -> bool:
-            if not isinstance(o, Feature):
-                return NotImplemented
-            else:
-                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all()
-
-    @dataclasses.dataclass
-    class ModelExample:
-        example_ids: List[str]
-        feature: Feature
-        label: torch.Tensor
-        some_constant: int = dataclasses.field(init=False)
-
-        def __post_init__(self):
-            self.some_constant = 7
-
-        def __eq__(self, o: object) -> bool:
-            if not isinstance(o, ModelExample):
-                return NotImplemented
-            else:
-                return (
-                    self.example_ids == o.example_ids
-                    and self.feature == o.feature
-                    and torch.equal(self.label, o.label)
-                    and self.some_constant == o.some_constant
-                )
-
-    @dataclasses.dataclass
-    class WithClassVar:
-        class_var: ClassVar[int] = 0
-        dummy: Any
-
-        def __eq__(self, o: object) -> bool:
-            if not isinstance(o, WithClassVar):
-                return NotImplemented
-            elif isinstance(self.dummy, torch.Tensor):
-                return torch.equal(self.dummy, o.dummy)
-            else:
-                return self.dummy == o.dummy
-
-    @dataclasses.dataclass
-    class WithInitVar:
-        dummy: Any
-        override: InitVar[Optional[Any]] = None
-
-        def __post_init__(self, override: Optional[Any]):
-            if override is not None:
-                self.dummy = override
-
-        def __eq__(self, o: object) -> bool:
-            if not isinstance(o, WithInitVar):
-                return NotImplemented
-            elif isinstance(self.dummy, torch.Tensor):
-                return torch.equal(self.dummy, o.dummy)
-            else:
-                return self.dummy == o.dummy
-
-    @dataclasses.dataclass
-    class WithClassAndInitVar:
-        class_var: ClassVar[torch.Tensor] = torch.tensor(0)
-        dummy: Any
-        override: InitVar[Optional[Any]] = torch.tensor(1)
-
-        def __post_init__(self, override: Optional[Any]):
-            if override is not None:
-                self.dummy = override
-
-        def __eq__(self, o: object) -> bool:
-            if not isinstance(o, WithClassAndInitVar):
-                return NotImplemented
-            elif isinstance(self.dummy, torch.Tensor):
-                return torch.equal(self.dummy, o.dummy)
-            else:
-                return self.dummy == o.dummy
-
     model_example = ModelExample(
         example_ids=["i-1", "i-2", "i-3"],
         feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
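
A side note on why the fixtures above include `ClassVar` and `InitVar` members: `dataclasses.fields()` reports neither pseudo-field, so the `apply_*` helpers never map over them and they survive a round trip untouched. A small self-contained illustration (`Demo` is a hypothetical class for this note only):

import dataclasses
from dataclasses import InitVar
from typing import Any, ClassVar


@dataclasses.dataclass
class Demo:
    class_var: ClassVar[int] = 0
    dummy: Any = None
    override: InitVar[Any] = None


# Only real fields are reported; ClassVar and InitVar entries are excluded.
print([f.name for f in dataclasses.fields(Demo)])  # -> ['dummy']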
@@ -303,6 +308,36 @@ def fn(a, b):
     assert reduced is None
 
 
+def test_apply_to_collections_dataclass():
+    to_reduce_1 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0]))
+    to_reduce_2 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0]))
+
+    def fn(a, b):
+        return a + b
+
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (torch.Tensor, numbers.Number, np.ndarray), fn)
+
+    assert reduced == Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0]))
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=to_reduce_1,
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
+    # different types
+    with pytest.raises(TypeError, match="Expected inputs to be dataclasses of the same type"):
+        apply_to_collections(to_reduce_1, [1, 2], (torch.Tensor, numbers.Number, np.ndarray), fn)
+
+    # unmatched fields
+    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
+        apply_to_collections(to_reduce_1, model_example, (torch.Tensor, numbers.Number, np.ndarray), fn)
+
+    classvar = WithClassVar(torch.arange(3))  # dataclass with same number but different type of fields
+    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
+        apply_to_collections(to_reduce_1, classvar, (torch.Tensor, numbers.Number, np.ndarray), fn)
+
+
 def test_apply_to_collection_frozen_dataclass():
     @dataclasses.dataclass(frozen=True)
     class Foo:

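Related to the frozen-dataclass test above: the new branch in `apply_to_collections` converts `dataclasses.FrozenInstanceError` into a `MisconfigurationException`, since the result is built by `setattr` on a deepcopy. A minimal sketch of that error path, assuming a build with this commit (`Frozen` is a hypothetical dataclass):

import dataclasses

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collections
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@dataclasses.dataclass(frozen=True)
class Frozen:
    t: torch.Tensor


a = Frozen(torch.tensor([1.0]))
b = Frozen(torch.tensor([2.0]))

try:
    apply_to_collections(a, b, torch.Tensor, lambda x, y: x + y)
except MisconfigurationException:
    # setattr on the frozen deepcopy raises dataclasses.FrozenInstanceError,
    # which the new code re-raises as a MisconfigurationException.
    pass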