forked from sktime/pytorch-forecasting
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_utils.py
More file actions
621 lines (517 loc) · 18.8 KB
/
_utils.py
File metadata and controls
621 lines (517 loc) · 18.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
"""
Helper functions for PyTorch forecasting
"""
from collections import namedtuple
from collections.abc import Callable
from contextlib import redirect_stdout
import inspect
import os
from typing import Any, Union
import lightning.pytorch as pl
import torch
from torch import nn
from torch.fft import irfft, rfft
import torch.nn.functional as F
from torch.nn.utils import rnn
def integer_histogram(
data: torch.LongTensor, min: None | int = None, max: None | int = None
) -> torch.Tensor:
"""
Create histogram of integers in predefined range
Args:
data: data for which to create histogram
min: minimum of histogram, is inferred from data by default
max: maximum of histogram, is inferred from data by default
Returns:
histogram
"""
uniques, counts = torch.unique(data, return_counts=True)
if min is None:
min = uniques.min()
if max is None:
max = uniques.max()
hist = torch.zeros(max - min + 1, dtype=torch.long, device=data.device).scatter(
dim=0, index=uniques - min, src=counts
)
return hist
def groupby_apply(
keys: torch.Tensor,
values: torch.Tensor,
bins: int = 95,
reduction: str = "mean",
return_histogram: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Groupby apply for torch tensors
Args:
keys: tensor of groups (``0`` to ``bins``)
values: values to aggregate - same size as keys
bins: total number of groups
reduction: either "mean" or "sum"
return_histogram: if to return histogram on top
Returns:
tensor of size ``bins`` with aggregated values
and optionally with counts of values
"""
if reduction == "mean":
reduce = torch.mean
elif reduction == "sum":
reduce = torch.sum
else:
raise ValueError(
f"Unknown reduction '{reduction}'. Expected one of {{'mean', 'sum'}}."
)
uniques, counts = keys.unique(return_counts=True)
groups = torch.stack(
[reduce(item) for item in torch.split_with_sizes(values, tuple(counts))]
)
reduced = torch.zeros(bins, dtype=values.dtype, device=values.device).scatter(
dim=0, index=uniques, src=groups
)
if return_histogram:
hist = torch.zeros(bins, dtype=torch.long, device=values.device).scatter(
dim=0, index=uniques, src=counts
)
return reduced, hist
else:
return reduced
def profile(
    function: Callable, profile_fname: str, filter: str = "", period=0.0001, **kwargs
):
    """
    Profile a given function with ``vmprof``.

    Args:
        function (Callable): function to profile
        profile_fname (str): path where to save profile (`.txt` file will be saved with line profile)
        filter (str, optional): filter name (e.g. module name) to filter profile. Defaults to "".
        period (float, optional): frequency of calling profiler in seconds. Defaults to 0.0001.
        **kwargs: keyword arguments forwarded to ``function``
    """  # noqa : E501
    # local import: vmprof is an optional dependency only needed here
    import vmprof
    from vmprof.show import LinesPrinter

    # profiler config
    with open(profile_fname, "wb+") as fd:
        # start profiler
        vmprof.enable(fd.fileno(), lines=True, period=period)
        # run function
        function(**kwargs)
        # stop profiler
        vmprof.disable()
    # write report to disk
    # NOTE(review): ``kwargs`` is forwarded to ``function`` above, yet a
    # ``lines`` flag is read out of it here while ``vmprof.enable`` was called
    # with ``lines=True`` unconditionally — looks inconsistent; confirm intent
    if kwargs.get("lines", True):
        with open(f"{os.path.splitext(profile_fname)[0]}.txt", "w") as f:
            with redirect_stdout(f):
                LinesPrinter(filter=filter).show(profile_fname)
def get_embedding_size(n: int, max_size: int = 100) -> int:
    """
    Determine empirically good embedding sizes (formula taken from fastai).

    Args:
        n (int): number of classes
        max_size (int, optional): maximum embedding size. Defaults to 100.

    Returns:
        int: embedding size
    """
    # binary (or smaller) cardinality needs only a single embedding dimension
    if n <= 2:
        return 1
    return min(round(1.6 * n**0.56), max_size)
def create_mask(
    size: int, lengths: torch.LongTensor, inverse: bool = False
) -> torch.BoolTensor:
    """
    Create boolean masks of shape len(lengths) x size.

    An entry at (i, j) is True if lengths[i] > j.

    Args:
        size (int): size of second dimension
        lengths (torch.LongTensor): tensor of lengths
        inverse (bool, optional): If true, boolean mask is inverted. Defaults to False.

    Returns:
        torch.BoolTensor: mask
    """
    # broadcast positions (1 x size) against lengths (n x 1)
    positions = torch.arange(size, device=lengths.device).unsqueeze(0)
    within_length = positions < lengths.unsqueeze(-1)
    if inverse:  # mark positions where values are
        return within_length
    return ~within_length  # mark positions where no values are
# cache of size -> next fast FFT length, shared across calls
_NEXT_FAST_LEN = {}


def next_fast_len(size):
    """
    Returns the next largest number ``n >= size`` whose prime factors are all
    2, 3, or 5. These sizes are efficient for fast fourier transforms.
    Equivalent to :func:`scipy.fftpack.next_fast_len`.

    Implementation from pyro

    :param int size: A positive number.
    :returns: A possibly larger number.
    :rtype int:
    """
    cached = _NEXT_FAST_LEN.get(size)
    if cached is not None:
        return cached
    assert isinstance(size, int) and size > 0
    candidate = size
    while True:
        # strip out all factors of 2, 3 and 5; a 5-smooth number reduces to 1
        residue = candidate
        for factor in (2, 3, 5):
            while residue % factor == 0:
                residue //= factor
        if residue == 1:
            _NEXT_FAST_LEN[size] = candidate
            return candidate
        candidate += 1
def autocorrelation(input, dim=0):
    """
    Computes the autocorrelation of samples at dimension ``dim``.

    Reference: https://en.wikipedia.org/wiki/Autocorrelation#Efficient_computation

    Implementation copied form `pyro <https://github.com/pyro-ppl/pyro/blob/dev/pyro/ops/stats.py>`_.

    :param torch.Tensor input: the input tensor.
    :param int dim: the dimension to calculate autocorrelation.
    :returns torch.Tensor: autocorrelation of ``input``.
    """
    # Adapted from Stan implementation
    # https://github.com/stan-dev/math/blob/develop/stan/math/prim/mat/fun/autocorrelation.hpp
    N = input.size(dim)
    # pad to a fast FFT length; doubling avoids circular-correlation wrap-around
    M = next_fast_len(N)
    M2 = 2 * M
    # transpose dim with -1 for Fourier transform
    input = input.transpose(dim, -1)
    # centering and padding x
    centered_signal = input - input.mean(dim=-1, keepdim=True)
    # Fourier transform
    freqvec = torch.view_as_real(rfft(centered_signal, n=M2))
    # take square of magnitude of freqvec (or freqvec x freqvec*)
    freqvec_gram = freqvec.pow(2).sum(-1)
    # inverse Fourier transform
    autocorr = irfft(freqvec_gram, n=M2)
    # truncate and normalize the result, then transpose back to original shape
    autocorr = autocorr[..., :N]
    # lag k is a sum over (N - k) overlapping terms — divide accordingly
    autocorr = autocorr / torch.tensor(
        range(N, 0, -1), dtype=input.dtype, device=input.device
    )
    # normalize so the lag-0 autocorrelation equals 1
    autocorr = autocorr / autocorr[..., :1]
    return autocorr.transpose(dim, -1)
def unpack_sequence(
sequence: torch.Tensor | rnn.PackedSequence,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Unpack RNN sequence.
Args:
sequence (Union[torch.Tensor, rnn.PackedSequence]): RNN packed sequence or tensor of which
first index are samples and second are timesteps
Returns:
Tuple[torch.Tensor, torch.Tensor]: tuple of unpacked sequence and length of samples
""" # noqa : E501
if isinstance(sequence, rnn.PackedSequence):
sequence, lengths = rnn.pad_packed_sequence(sequence, batch_first=True)
# batch sizes reside on the CPU by default -> we need to bring them to GPU
lengths = lengths.to(sequence.device)
else:
lengths = torch.ones(
sequence.size(0), device=sequence.device, dtype=torch.long
) * sequence.size(1)
return sequence, lengths
def concat_sequences(
sequences: list[torch.Tensor] | list[rnn.PackedSequence],
) -> torch.Tensor | rnn.PackedSequence:
"""
Concatenate RNN sequences.
Args:
sequences (Union[List[torch.Tensor], List[rnn.PackedSequence]): list of RNN packed sequences or tensors of which
first index are samples and second are timesteps
Returns:
Union[torch.Tensor, rnn.PackedSequence]: concatenated sequence
""" # noqa : E501
if isinstance(sequences[0], rnn.PackedSequence):
return rnn.pack_sequence(sequences, enforce_sorted=False)
elif isinstance(sequences[0], torch.Tensor):
return torch.cat(sequences, dim=0)
elif isinstance(sequences[0], tuple | list):
return tuple(
concat_sequences([sequences[ii][i] for ii in range(len(sequences))])
for i in range(len(sequences[0]))
)
else:
raise ValueError("Unsupported sequence type")
def padded_stack(
tensors: list[torch.Tensor],
side: str = "right",
mode: str = "constant",
value: int | float = 0,
) -> torch.Tensor:
"""
Stack tensors along first dimension and pad them along last dimension to ensure their size is equal.
Args:
tensors (List[torch.Tensor]): list of tensors to stack
side (str): side on which to pad - "left" or "right". Defaults to "right".
mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'
value (Union[int, float]): value to use for constant padding
Returns:
torch.Tensor: stacked tensor
""" # noqa : E501
full_size = max([x.size(-1) for x in tensors])
def make_padding(pad):
if side == "left":
return (pad, 0)
elif side == "right":
return (0, pad)
else:
raise ValueError(f"side for padding '{side}' is unknown")
out = torch.stack(
[
(
F.pad(x, make_padding(full_size - x.size(-1)), mode=mode, value=value)
if full_size - x.size(-1) > 0
else x
)
for x in tensors
],
dim=0,
)
return out
def to_list(value: Any) -> list[Any]:
    """
    Convert value or list to list of values.

    If already list, return object directly

    Args:
        value (Any): value to convert

    Returns:
        List[Any]: list of values
    """
    # PackedSequence is a tuple subclass, so it must be excluded explicitly
    if isinstance(value, rnn.PackedSequence):
        return [value]
    if isinstance(value, (tuple, list)):
        return value
    return [value]
def unsqueeze_like(tensor: torch.Tensor, like: torch.Tensor):
    """
    Unsqueeze last dimensions of tensor to match another tensor's number of dimensions.

    Args:
        tensor (torch.Tensor): tensor to unsqueeze
        like (torch.Tensor): tensor whose dimensions to match

    Raises:
        ValueError: if ``tensor`` has more dimensions than ``like``
    """
    missing_dims = like.ndim - tensor.ndim
    if missing_dims < 0:
        raise ValueError(f"tensor.ndim={tensor.ndim} > like.ndim={like.ndim}")
    if missing_dims == 0:
        return tensor
    # append one singleton axis per missing dimension via None-indexing
    return tensor[(...,) + (None,) * missing_dims]
def apply_to_list(obj: list[Any] | Any, func: Callable) -> list[Any] | Any:
"""
Apply function to a list of objects or directly if passed value is not a list.
This is useful if the passed object could be either a list to whose elements
a function needs to be applied or just an object to which to apply the function.
Args:
obj (Union[List[Any], Any]): list/tuple on whose elements to apply function,
otherwise object to whom to apply function
func (Callable): function to apply
Returns:
Union[List[Any], Any]: list of objects or object depending on function output
and if input ``obj`` is of type list/tuple
"""
if isinstance(obj, tuple | list) and not isinstance(obj, rnn.PackedSequence):
return [func(o) for o in obj]
else:
return func(obj)
class OutputMixIn:
"""
MixIn to give namedtuple some access capabilities of a dictionary
"""
def __getitem__(self, k):
if isinstance(k, str):
return getattr(self, k)
else:
return super().__getitem__(k)
def get(self, k, default=None):
return getattr(self, k, default)
def items(self):
return zip(self._fields, self)
def keys(self):
return self._fields
def iget(self, idx: int | slice):
"""Select item(s) row-wise.
Args:
idx ([int, slice]): item to select
Returns:
Output of single item.
"""
return self.__class__(*(x[idx] for x in self))
class TupleOutputMixIn:
    """MixIn to give output namedtuple-like access capabilities via a ``to_network_output()`` function."""  # noqa : E501

    def to_network_output(self, **results):
        """
        Convert output into a named (and immutable) tuple.

        This allows tracing the modules as graphs and prevents modifying the output.

        Returns:
            named tuple
        """
        # lazily build the namedtuple subclass once and cache it on the
        # instance; subsequent calls reuse the same class (same field order)
        if not hasattr(self, "_output_class"):
            base_tuple = namedtuple("output", results)

            class Output(OutputMixIn, base_tuple):
                pass

            self._output_class = Output
        return self._output_class(**results)
def move_to_device(
    x: dict[str, torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]]
    | torch.Tensor
    | list[torch.Tensor]
    | tuple[torch.Tensor],
    device: str | torch.DeviceObjType,
) -> (
    dict[str, torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]]
    | torch.Tensor
    | list[torch.Tensor]
    | tuple[torch.Tensor]
):
    """
    Move object to device.

    Args:
        x (dictionary of list of tensors): object (e.g. dictionary) of tensors to move to device
        device (Union[str, torch.DeviceObjType]): device, e.g. "cpu"

    Returns:
        x on targeted device
    """  # noqa: E501
    if isinstance(device, str):
        if device == "mps":
            # only use MPS when the backend exists and is actually usable,
            # otherwise fall back to CPU
            if hasattr(torch.backends, device):
                if torch.backends.mps.is_available() and torch.backends.mps.is_built():
                    device = torch.device("mps")
                else:
                    device = torch.device("cpu")
        else:
            device = torch.device(device)
    if isinstance(x, dict):
        for name in x.keys():
            x[name] = move_to_device(x[name], device=device)
    elif isinstance(x, OutputMixIn):
        # named tuples are immutable and ``Tensor.to`` is not in-place:
        # rebuild the tuple from the moved fields (previously the moved
        # tensors were discarded, leaving the fields on the old device)
        return x.__class__(*(move_to_device(xi, device=device) for xi in x))
    elif isinstance(x, torch.Tensor) and x.device != device:
        x = x.to(device)
    elif isinstance(x, (tuple, list)) and len(x) > 0:
        # len guard: previously ``x[0].device`` raised on empty sequences;
        # the isinstance check also lets nested containers recurse safely
        if not isinstance(x[0], torch.Tensor) or x[0].device != device:
            x = [move_to_device(xi, device=device) for xi in x]
    return x
def detach(
x: dict[str, torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]]
| torch.Tensor
| list[torch.Tensor]
| tuple[torch.Tensor],
) -> (
dict[str, torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]]
| torch.Tensor
| list[torch.Tensor]
| tuple[torch.Tensor]
):
"""
Detach object
Args:
x: object to detach
Returns:
detached object
"""
if isinstance(x, torch.Tensor):
return x.detach()
elif isinstance(x, dict):
return {name: detach(xi) for name, xi in x.items()}
elif isinstance(x, OutputMixIn):
return x.__class__(**{name: detach(xi) for name, xi in x.items()})
elif isinstance(x, tuple | list):
return [detach(xi) for xi in x]
else:
return x
def masked_op(
    tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None
) -> torch.Tensor:
    """Calculate operation on masked tensor.

    Args:
        tensor (torch.Tensor): tensor to conduct operation over
        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
        dim (int, optional): dimension to average over. Defaults to 0.
        mask (torch.Tensor, optional): boolean mask to apply (True=will take mean, False=ignore).
            Masks nan values by default.

    Returns:
        torch.Tensor: tensor with averaged out dimension
    """  # noqa : E501
    if mask is None:
        # by default ignore NaN entries
        mask = ~torch.isnan(tensor)
    # zero out masked entries so they do not contribute to the sum
    total = tensor.masked_fill(~mask, 0.0).sum(dim=dim)
    if op == "sum":
        return total
    if op == "mean":
        return total / mask.sum(dim=dim)
    raise ValueError(f"unknown operation {op}")
def repr_class(
    obj,
    attributes: list[str] | dict[str, Any],
    max_characters_before_break: int = 100,
    extra_attributes: dict[str, Any] = None,
) -> str:
    """Print class name and parameters.

    Args:
        obj: class to format
        attributes (Union[List[str], Dict[str, Any]]): list of attribute names to show
            or dictionary of attributes and values to show
        max_characters_before_break (int): number of characters before breaking
            the representation into multiple lines
        extra_attributes (Dict[str, Any]): extra attributes to show in square brackets

    Returns:
        str
    """  # noqa E501
    if extra_attributes is None:
        extra_attributes = {}
    # get attributes
    if isinstance(attributes, tuple | list):
        # resolve attribute names on the object, silently skipping missing ones
        attributes = {
            name: getattr(obj, name) for name in attributes if hasattr(obj, name)
        }
    attributes_strings = [f"{name}={repr(value)}" for name, value in attributes.items()]
    # get header
    header_name = obj.__class__.__name__
    # add extra attributes
    if len(extra_attributes) > 0:
        extra_attributes_strings = [
            f"{name}={repr(value)}" for name, value in extra_attributes.items()
        ]
        if (
            len(header_name) + 2 + len(", ".join(extra_attributes_strings))
            > max_characters_before_break
        ):
            # NOTE(review): when the extra attributes overflow the budget, the
            # *main* attribute strings are rendered inside the brackets here —
            # confirm this is intended rather than extra_attributes_strings
            header = f"{header_name}[\n\t" + ",\n\t".join(attributes_strings) + "\n]("
        else:
            header = f"{header_name}[{', '.join(extra_attributes_strings)}]("
    else:
        header = f"{header_name}("
    # create final representation
    attributes_string = ", ".join(attributes_strings)
    # break onto multiple lines when the last header line plus the attribute
    # list would exceed the character budget
    if (
        len(attributes_string) + len(header.split("\n")[-1]) + 1
        > max_characters_before_break
    ):
        attributes_string = "\n\t" + ",\n\t".join(attributes_strings) + "\n"
    return f"{header}{attributes_string})"
class InitialParameterRepresenterMixIn:
    """Mixin that derives representations from a class's ``__init__`` signature."""

    def __repr__(self) -> str:
        # torch modules already have a rich representation — keep it
        if isinstance(self, nn.Module):
            return super().__repr__()
        init_params = list(inspect.signature(self.__class__).parameters.keys())
        return repr_class(self, attributes=init_params)

    def extra_repr(self) -> str:
        """
        Return extra information about parameters for representation/logging.
        """
        if isinstance(self, pl.LightningModule):
            # indent hparams one level for readable nested output
            return "\t" + repr(self.hparams).replace("\n", "\n\t")
        init_params = list(inspect.signature(self.__class__).parameters.keys())
        return ", ".join(
            f"{name}={repr(getattr(self, name))}"
            for name in init_params
            if hasattr(self, name)
        )