25 | 25 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
26 | 26 |
27 | 27 |
| 28 | +@dataclasses.dataclass |
| 29 | +class Feature: |
| 30 | +    input_ids: torch.Tensor |
| 31 | +    segment_ids: np.ndarray |
| 32 | + |
| 33 | +    def __eq__(self, o: object) -> bool: |
| 34 | +        if not isinstance(o, Feature): |
| 35 | +            return NotImplemented |
| 36 | + |
| 37 | +        return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all() |
| 38 | + |
| 39 | + |
| 40 | +@dataclasses.dataclass |
| 41 | +class ModelExample: |
| 42 | +    example_ids: List[str] |
| 43 | +    feature: Feature |
| 44 | +    label: torch.Tensor |
| 45 | +    some_constant: int = dataclasses.field(init=False) |
| 46 | + |
| 47 | +    def __post_init__(self): |
| 48 | +        self.some_constant = 7 |
| 49 | + |
| 50 | +    def __eq__(self, o: object) -> bool: |
| 51 | +        if not isinstance(o, ModelExample): |
| 52 | +            return NotImplemented |
| 53 | + |
| 54 | +        return ( |
| 55 | +            self.example_ids == o.example_ids |
| 56 | +            and self.feature == o.feature |
| 57 | +            and torch.equal(self.label, o.label) |
| 58 | +            and self.some_constant == o.some_constant |
| 59 | +        ) |
| 60 | + |
| 61 | + |
| 62 | +@dataclasses.dataclass |
| 63 | +class WithClassVar: |
| 64 | +    class_var: ClassVar[int] = 0 |
| 65 | +    dummy: Any |
| 66 | + |
| 67 | +    def __eq__(self, o: object) -> bool: |
| 68 | +        if not isinstance(o, WithClassVar): |
| 69 | +            return NotImplemented |
| 70 | +        elif isinstance(self.dummy, torch.Tensor): |
| 71 | +            return torch.equal(self.dummy, o.dummy) |
| 72 | + |
| 73 | +        return self.dummy == o.dummy |
| 74 | + |
| 75 | + |
| 76 | +@dataclasses.dataclass |
| 77 | +class WithInitVar: |
| 78 | +    dummy: Any |
| 79 | +    override: InitVar[Optional[Any]] = None |
| 80 | + |
| 81 | +    def __post_init__(self, override: Optional[Any]): |
| 82 | +        if override is not None: |
| 83 | +            self.dummy = override |
| 84 | + |
| 85 | +    def __eq__(self, o: object) -> bool: |
| 86 | +        if not isinstance(o, WithInitVar): |
| 87 | +            return NotImplemented |
| 88 | +        elif isinstance(self.dummy, torch.Tensor): |
| 89 | +            return torch.equal(self.dummy, o.dummy) |
| 90 | + |
| 91 | +        return self.dummy == o.dummy |
| 92 | + |
| 93 | + |
| 94 | +@dataclasses.dataclass |
| 95 | +class WithClassAndInitVar: |
| 96 | +    class_var: ClassVar[torch.Tensor] = torch.tensor(0) |
| 97 | +    dummy: Any |
| 98 | +    override: InitVar[Optional[Any]] = torch.tensor(1) |
| 99 | + |
| 100 | +    def __post_init__(self, override: Optional[Any]): |
| 101 | +        if override is not None: |
| 102 | +            self.dummy = override |
| 103 | + |
| 104 | +    def __eq__(self, o: object) -> bool: |
| 105 | +        if not isinstance(o, WithClassAndInitVar): |
| 106 | +            return NotImplemented |
| 107 | +        elif isinstance(self.dummy, torch.Tensor): |
| 108 | +            return torch.equal(self.dummy, o.dummy) |
| 109 | + |
| 110 | +        return self.dummy == o.dummy |
| 111 | + |
| 112 | + |
28 | 113 | def test_recursive_application_to_collection():
29 | 114 |     ntc = namedtuple("Foo", ["bar"])
30 | 115 |
31 | | -    @dataclasses.dataclass |
32 | | -    class Feature: |
33 | | -        input_ids: torch.Tensor |
34 | | -        segment_ids: np.ndarray |
35 | | - |
36 | | -        def __eq__(self, o: object) -> bool: |
37 | | -            if not isinstance(o, Feature): |
38 | | -                return NotImplemented |
39 | | -            else: |
40 | | -                return torch.equal(self.input_ids, o.input_ids) and np.equal(self.segment_ids, o.segment_ids).all() |
41 | | - |
42 | | -    @dataclasses.dataclass |
43 | | -    class ModelExample: |
44 | | -        example_ids: List[str] |
45 | | -        feature: Feature |
46 | | -        label: torch.Tensor |
47 | | -        some_constant: int = dataclasses.field(init=False) |
48 | | - |
49 | | -        def __post_init__(self): |
50 | | -            self.some_constant = 7 |
51 | | - |
52 | | -        def __eq__(self, o: object) -> bool: |
53 | | -            if not isinstance(o, ModelExample): |
54 | | -                return NotImplemented |
55 | | -            else: |
56 | | -                return ( |
57 | | -                    self.example_ids == o.example_ids |
58 | | -                    and self.feature == o.feature |
59 | | -                    and torch.equal(self.label, o.label) |
60 | | -                    and self.some_constant == o.some_constant |
61 | | -                ) |
62 | | - |
63 | | -    @dataclasses.dataclass |
64 | | -    class WithClassVar: |
65 | | -        class_var: ClassVar[int] = 0 |
66 | | -        dummy: Any |
67 | | - |
68 | | -        def __eq__(self, o: object) -> bool: |
69 | | -            if not isinstance(o, WithClassVar): |
70 | | -                return NotImplemented |
71 | | -            elif isinstance(self.dummy, torch.Tensor): |
72 | | -                return torch.equal(self.dummy, o.dummy) |
73 | | -            else: |
74 | | -                return self.dummy == o.dummy |
75 | | - |
76 | | -    @dataclasses.dataclass |
77 | | -    class WithInitVar: |
78 | | -        dummy: Any |
79 | | -        override: InitVar[Optional[Any]] = None |
80 | | - |
81 | | -        def __post_init__(self, override: Optional[Any]): |
82 | | -            if override is not None: |
83 | | -                self.dummy = override |
84 | | - |
85 | | -        def __eq__(self, o: object) -> bool: |
86 | | -            if not isinstance(o, WithInitVar): |
87 | | -                return NotImplemented |
88 | | -            elif isinstance(self.dummy, torch.Tensor): |
89 | | -                return torch.equal(self.dummy, o.dummy) |
90 | | -            else: |
91 | | -                return self.dummy == o.dummy |
92 | | - |
93 | | -    @dataclasses.dataclass |
94 | | -    class WithClassAndInitVar: |
95 | | -        class_var: ClassVar[torch.Tensor] = torch.tensor(0) |
96 | | -        dummy: Any |
97 | | -        override: InitVar[Optional[Any]] = torch.tensor(1) |
98 | | - |
99 | | -        def __post_init__(self, override: Optional[Any]): |
100 | | -            if override is not None: |
101 | | -                self.dummy = override |
102 | | - |
103 | | -        def __eq__(self, o: object) -> bool: |
104 | | -            if not isinstance(o, WithClassAndInitVar): |
105 | | -                return NotImplemented |
106 | | -            elif isinstance(self.dummy, torch.Tensor): |
107 | | -                return torch.equal(self.dummy, o.dummy) |
108 | | -            else: |
109 | | -                return self.dummy == o.dummy |
110 | | - |
111 | 116 |     model_example = ModelExample(
112 | 117 |         example_ids=["i-1", "i-2", "i-3"],
113 | 118 |         feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])),
@@ -303,6 +308,36 @@ def fn(a, b):
303 | 308 |     assert reduced is None
304 | 309 |
305 | 310 |
| 311 | +def test_apply_to_collections_dataclass(): |
| 312 | +    to_reduce_1 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])) |
| 313 | +    to_reduce_2 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=np.array([4.0, 5.0, 6.0])) |
| 314 | + |
| 315 | +    def fn(a, b): |
| 316 | +        return a + b |
| 317 | + |
| 318 | +    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (torch.Tensor, numbers.Number, np.ndarray), fn) |
| 319 | + |
| 320 | +    assert reduced == Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=np.array([8.0, 10.0, 12.0])) |
| 321 | + |
| 322 | +    model_example = ModelExample( |
| 323 | +        example_ids=["i-1", "i-2", "i-3"], |
| 324 | +        feature=to_reduce_1, |
| 325 | +        label=torch.tensor([7.0, 8.0, 9.0]), |
| 326 | +    ) |
| 327 | + |
| 328 | +    # different types |
| 329 | +    with pytest.raises(TypeError, match="Expected inputs to be dataclasses of the same type"): |
| 330 | +        apply_to_collections(to_reduce_1, [1, 2], (torch.Tensor, numbers.Number, np.ndarray), fn) |
| 331 | + |
| 332 | +    # unmatched fields |
| 333 | +    with pytest.raises(TypeError, match="Dataclasses fields do not match"): |
| 334 | +        apply_to_collections(to_reduce_1, model_example, (torch.Tensor, numbers.Number, np.ndarray), fn) |
| 335 | + |
| 336 | +    classvar = WithClassVar(torch.arange(3))  # dataclass with same number but different type of fields |
| 337 | +    with pytest.raises(TypeError, match="Dataclasses fields do not match"): |
| 338 | +        apply_to_collections(to_reduce_1, classvar, (torch.Tensor, numbers.Number, np.ndarray), fn) |
| 339 | + |
| 340 | + |
306 | 341 | def test_apply_to_collection_frozen_dataclass():
307 | 342 |     @dataclasses.dataclass(frozen=True)
308 | 343 |     class Foo:
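The new test_apply_to_collections_dataclass above exercises the dataclass support in apply_to_collections: two dataclass instances of the same type are walked field by field, the callable is applied to each pair of matching leaves, and mismatched types or fields raise TypeError. A minimal sketch of that behaviour, using a hypothetical Pair dataclass (not part of this diff) and assuming the pytorch_lightning.utilities.apply_func import used by this test module together with a Lightning version that includes this change:

import dataclasses

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collections


@dataclasses.dataclass
class Pair:
    left: torch.Tensor
    right: torch.Tensor


a = Pair(left=torch.tensor([1.0, 2.0]), right=torch.tensor([3.0, 4.0]))
b = Pair(left=torch.tensor([10.0, 20.0]), right=torch.tensor([30.0, 40.0]))

# Fields are paired up by position and the callable is applied to each pair of tensor
# leaves; the result is a new Pair, mirroring what the test asserts for Feature above.
summed = apply_to_collections(a, b, torch.Tensor, lambda x, y: x + y)
assert torch.equal(summed.left, torch.tensor([11.0, 22.0]))
assert torch.equal(summed.right, torch.tensor([33.0, 44.0]))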