|
7 | 7 |
|
8 | 8 | # pyre-strict |
9 | 9 |
|
10 | | -from typing import Iterable, Iterator, List, Optional, Tuple |
| 10 | +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | from torch import nn, Tensor |
@@ -236,3 +236,51 @@ def configure_optimizers_and_lr_scheduler( |
236 | 236 | my_optimizer, gamma=0.9 |
237 | 237 | ) |
238 | 238 | return my_optimizer, my_lr_scheduler |
| 239 | + |
| 240 | + |
| 241 | +class DummyStatefulDataLoader: |
| 242 | + """Dummy Dataloader that implements state_dict and load_state_dict""" |
| 243 | + |
| 244 | + def __init__(self, dataloader: DataLoader) -> None: |
| 245 | + self.dataloader = dataloader |
| 246 | + |
| 247 | + def state_dict(self) -> Dict[str, Any]: |
| 248 | + return {"current_batch": 1} |
| 249 | + |
| 250 | + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
| 251 | + return None |
| 252 | + |
| 253 | + def __iter__(self) -> Iterator[object]: |
| 254 | + return iter(self.dataloader) |
| 255 | + |
| 256 | + |
| 257 | +def generate_dummy_stateful_dataloader( |
| 258 | + num_samples: int, input_dim: int, batch_size: int |
| 259 | +) -> DummyStatefulDataLoader: |
| 260 | + return DummyStatefulDataLoader( |
| 261 | + DataLoader( |
| 262 | + dataset=RandomIterableDataset(input_dim, num_samples), |
| 263 | + batch_size=batch_size, |
| 264 | + ) |
| 265 | + ) |
| 266 | + |
| 267 | + |
| 268 | +class DummyMeanMetric: |
| 269 | + def __init__(self) -> None: |
| 270 | + super().__init__() |
| 271 | + self.sum: float = 0.0 |
| 272 | + self.count: int = 0 |
| 273 | + |
| 274 | + def update(self, value: float) -> None: |
| 275 | + self.sum += value |
| 276 | + self.count += 1 |
| 277 | + |
| 278 | + def compute(self) -> float: |
| 279 | + return self.sum / self.count if self.count > 0 else 0.0 |
| 280 | + |
| 281 | + def state_dict(self) -> Dict[str, Any]: |
| 282 | + return {"sum": self.sum, "count": self.count} |
| 283 | + |
| 284 | + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
| 285 | + self.sum = state_dict["sum"] |
| 286 | + self.count = state_dict["count"] |
0 commit comments