1414import dataclasses
1515import numbers
1616from collections import defaultdict , namedtuple , OrderedDict
17- from typing import List
17+ from dataclasses import InitVar
18+ from typing import Any , ClassVar , List , Optional
1819
1920import numpy as np
2021import pytest
@@ -31,6 +32,12 @@ class Feature:
3132 input_ids : torch .Tensor
3233 segment_ids : np .ndarray
3334
35+ def __eq__ (self , o : object ) -> bool :
36+ if not isinstance (o , Feature ):
37+ return NotImplemented
38+ else :
39+ return torch .equal (self .input_ids , o .input_ids ) and np .equal (self .segment_ids , o .segment_ids ).all ()
40+
3441 @dataclasses .dataclass
3542 class ModelExample :
3643 example_ids : List [str ]
@@ -41,6 +48,71 @@ class ModelExample:
4148 def __post_init__ (self ):
4249 self .some_constant = 7
4350
51+ def __eq__ (self , o : object ) -> bool :
52+ if not isinstance (o , ModelExample ):
53+ return NotImplemented
54+ else :
55+ return (
56+ self .example_ids == o .example_ids
57+ and self .feature == o .feature
58+ and torch .equal (self .label , o .label )
59+ and self .some_constant == o .some_constant
60+ )
61+
62+ @dataclasses .dataclass
63+ class WithClassVar :
64+ class_var : ClassVar [int ] = 0
65+ dummy : Any
66+
67+ def __eq__ (self , o : object ) -> bool :
68+ if not isinstance (o , WithClassVar ):
69+ return NotImplemented
70+ elif isinstance (self .dummy , torch .Tensor ):
71+ return torch .equal (self .dummy , o .dummy )
72+ else :
73+ return self .dummy == o .dummy
74+
75+ @dataclasses .dataclass
76+ class WithInitVar :
77+ dummy : Any
78+ override : InitVar [Optional [Any ]] = None
79+
80+ def __post_init__ (self , override : Optional [Any ]):
81+ if override is not None :
82+ self .dummy = override
83+
84+ def __eq__ (self , o : object ) -> bool :
85+ if not isinstance (o , WithInitVar ):
86+ return NotImplemented
87+ elif isinstance (self .dummy , torch .Tensor ):
88+ return torch .equal (self .dummy , o .dummy )
89+ else :
90+ return self .dummy == o .dummy
91+
92+ @dataclasses .dataclass
93+ class WithClassAndInitVar :
94+ class_var : ClassVar [torch .Tensor ] = torch .tensor (0 )
95+ dummy : Any
96+ override : InitVar [Optional [Any ]] = torch .tensor (1 )
97+
98+ def __post_init__ (self , override : Optional [Any ]):
99+ if override is not None :
100+ self .dummy = override
101+
102+ def __eq__ (self , o : object ) -> bool :
103+ if not isinstance (o , WithClassAndInitVar ):
104+ return NotImplemented
105+ elif isinstance (self .dummy , torch .Tensor ):
106+ return torch .equal (self .dummy , o .dummy )
107+ else :
108+ return self .dummy == o .dummy
109+
110+ model_example = ModelExample (
111+ example_ids = ["i-1" , "i-2" , "i-3" ],
112+ feature = Feature (input_ids = torch .tensor ([1.0 , 2.0 , 3.0 ]), segment_ids = np .array ([4.0 , 5.0 , 6.0 ])),
113+ label = torch .tensor ([7.0 , 8.0 , 9.0 ]),
114+ )
115+
44116 to_reduce = {
45117 "a" : torch .tensor ([1.0 ]), # Tensor
46118 "b" : [torch .tensor ([2.0 ])], # list
@@ -50,13 +122,18 @@ def __post_init__(self):
50122 "f" : "this_is_a_dummy_str" , # string
51123 "g" : 12.0 , # number
52124 "h" : Feature (input_ids = torch .tensor ([1.0 , 2.0 , 3.0 ]), segment_ids = np .array ([4.0 , 5.0 , 6.0 ])), # dataclass
53- "i" : ModelExample (
54- example_ids = ["i-1" , "i-2" , "i-3" ],
55- feature = Feature (input_ids = torch .tensor ([1.0 , 2.0 , 3.0 ]), segment_ids = np .array ([4.0 , 5.0 , 6.0 ])),
56- label = torch .tensor ([7.0 , 8.0 , 9.0 ]),
57- ), # nested dataclass
125+ "i" : model_example , # nested dataclass
126+ "j" : WithClassVar (torch .arange (3 )), # dataclass with class variable
127+ "k" : WithInitVar ("this_gets_overridden" , torch .tensor ([2.0 ])), # dataclass with init-only variable
128+ "l" : WithClassAndInitVar (model_example , None ), # nested dataclass with class and init-only variables
58129 }
59130
131+ model_example_result = ModelExample (
132+ example_ids = ["i-1" , "i-2" , "i-3" ],
133+ feature = Feature (input_ids = torch .tensor ([2.0 , 4.0 , 6.0 ]), segment_ids = np .array ([8.0 , 10.0 , 12.0 ])),
134+ label = torch .tensor ([14.0 , 16.0 , 18.0 ]),
135+ )
136+
60137 expected_result = {
61138 "a" : torch .tensor ([2.0 ]),
62139 "b" : [torch .tensor ([4.0 ])],
@@ -66,32 +143,31 @@ def __post_init__(self):
66143 "f" : "this_is_a_dummy_str" ,
67144 "g" : 24.0 ,
68145 "h" : Feature (input_ids = torch .tensor ([2.0 , 4.0 , 6.0 ]), segment_ids = np .array ([8.0 , 10.0 , 12.0 ])),
69- "i" : ModelExample (
70- example_ids = ["i-1" , "i-2" , "i-3" ],
71- feature = Feature (input_ids = torch .tensor ([2.0 , 4.0 , 6.0 ]), segment_ids = np .array ([8.0 , 10.0 , 12.0 ])),
72- label = torch .tensor ([14.0 , 16.0 , 18.0 ]),
73- ),
146+ "i" : model_example_result ,
147+ "j" : WithClassVar (torch .arange (0 , 6 , 2 )),
148+ "k" : WithInitVar (torch .tensor ([4.0 ])),
149+ "l" : WithClassAndInitVar (model_example_result , None ),
74150 }
75151
76152 reduced = apply_to_collection (to_reduce , (torch .Tensor , numbers .Number , np .ndarray ), lambda x : x * 2 )
77153
78- assert isinstance (reduced , dict ), " Type Consistency of dict not preserved"
154+ assert isinstance (reduced , dict ), "Type Consistency of dict not preserved"
79155 assert all (x in reduced for x in to_reduce ), "Not all entries of the dict were preserved"
80156 assert all (
81157 isinstance (reduced [k ], type (expected_result [k ])) for k in to_reduce
82158 ), "At least one type was not correctly preserved"
83159
84160 assert isinstance (reduced ["a" ], torch .Tensor ), "Reduction Result of a Tensor should be a Tensor"
85- assert torch .allclose (expected_result ["a" ], reduced ["a" ]), "Reduction of a tensor does not yield the expected value"
161+ assert torch .equal (expected_result ["a" ], reduced ["a" ]), "Reduction of a tensor does not yield the expected value"
86162
87163 assert isinstance (reduced ["b" ], list ), "Reduction Result of a list should be a list"
88164 assert all (
89- torch .allclose (x , y ) for x , y in zip (reduced ["b" ], expected_result ["b" ])
165+ torch .equal (x , y ) for x , y in zip (reduced ["b" ], expected_result ["b" ])
90166 ), "At least one value of list reduction did not come out as expected"
91167
92168 assert isinstance (reduced ["c" ], tuple ), "Reduction Result of a tuple should be a tuple"
93169 assert all (
94- torch .allclose (x , y ) for x , y in zip (reduced ["c" ], expected_result ["c" ])
170+ torch .equal (x , y ) for x , y in zip (reduced ["c" ], expected_result ["c" ])
95171 ), "At least one value of tuple reduction did not come out as expected"
96172
97173 assert isinstance (reduced ["d" ], ntc ), "Type Consistency for named tuple not given"
@@ -109,34 +185,30 @@ def __post_init__(self):
109185 assert isinstance (reduced ["g" ], numbers .Number ), "Reduction of a number should result in a number"
110186 assert reduced ["g" ] == expected_result ["g" ], "Reduction of a number did not yield the desired result"
111187
112- assert dataclasses .is_dataclass (reduced ["h" ]) and not isinstance (
113- reduced ["h" ], type
114- ), "Reduction of a dataclass should result in a dataclass"
115- assert torch .allclose (
116- reduced ["h" ].input_ids , expected_result ["h" ].input_ids
117- ), "Reduction of a dataclass did not yield the desired result"
118- assert np .allclose (
119- reduced ["h" ].segment_ids , expected_result ["h" ].segment_ids
120- ), "Reduction of a dataclass did not yield the desired result"
121-
122- assert dataclasses .is_dataclass (reduced ["i" ]) and not isinstance (
123- reduced ["i" ], type
124- ), "Reduction of a dataclass should result in a dataclass"
125- assert dataclasses .is_dataclass (reduced ["i" ].feature ) and not isinstance (
126- reduced ["i" ].feature , type
127- ), "Reduction of a nested dataclass should result in a nested dataclass"
128- assert (
129- reduced ["i" ].example_ids == expected_result ["i" ].example_ids
130- ), "Reduction of a nested dataclass did not yield the desired result"
131- assert torch .allclose (
132- reduced ["i" ].label , expected_result ["i" ].label
133- ), "Reduction of a nested dataclass did not yield the desired result"
134- assert torch .allclose (
135- reduced ["i" ].feature .input_ids , expected_result ["i" ].feature .input_ids
136- ), "Reduction of a nested dataclass did not yield the desired result"
137- assert np .allclose (
138- reduced ["i" ].feature .segment_ids , expected_result ["i" ].feature .segment_ids
139- ), "Reduction of a nested dataclass did not yield the desired result"
188+ def _assert_dataclass_reduction (actual , expected , dataclass_type : str = "" ):
189+ assert dataclasses .is_dataclass (actual ) and not isinstance (
190+ actual , type
191+ ), f"Reduction of a { dataclass_type } dataclass should result in a dataclass"
192+ for field in dataclasses .fields (actual ):
193+ if dataclasses .is_dataclass (field .type ):
194+ _assert_dataclass_reduction (getattr (actual , field .name ), getattr (expected , field .name ), "nested" )
195+ assert actual == expected , f"Reduction of a { dataclass_type } dataclass did not yield the desired result"
196+
197+ _assert_dataclass_reduction (reduced ["h" ], expected_result ["h" ])
198+
199+ _assert_dataclass_reduction (reduced ["i" ], expected_result ["i" ])
200+
201+ dataclass_type = "ClassVar-containing"
202+ _assert_dataclass_reduction (reduced ["j" ], expected_result ["j" ], dataclass_type )
203+ assert WithClassVar .class_var == 0 , f"Reduction of a { dataclass_type } dataclass should not change the class var"
204+
205+ _assert_dataclass_reduction (reduced ["k" ], expected_result ["k" ], "InitVar-containing" )
206+
207+ dataclass_type = "Class-and-InitVar-containing"
208+ _assert_dataclass_reduction (reduced ["l" ], expected_result ["l" ], dataclass_type )
209+ assert torch .equal (
210+ WithClassAndInitVar .class_var , torch .tensor (0 )
211+ ), f"Reduction of a { dataclass_type } dataclass should not change the class var"
140212
141213 # mapping support
142214 reduced = apply_to_collection ({"a" : 1 , "b" : 2 }, int , lambda x : str (x ))
0 commit comments