@@ -24,15 +24,15 @@
 from distutils.util import strtobool
 from functools import wraps
 from importlib import import_module
-from typing import Any, Callable, cast, Dict, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Tuple, TypeVar, Union

 import numpy as np
 import torch
 from packaging.version import parse
 from tensordict import unravel_key

 from tensordict.utils import NestedKey
-from torch import multiprocessing as mp
+from torch import multiprocessing as mp, Tensor

 try:
     from torch.compiler import is_compiling
@@ -872,6 +872,70 @@ def set_mode(self, type: Any | None) -> None:
         self._mode = type


+def _standardize(
+    input: Tensor,
+    exclude_dims: Tuple[int, ...] = (),
+    mean: Tensor | None = None,
+    std: Tensor | None = None,
+    eps: float | None = None,
+) -> Tensor:
+    """Standardizes the input tensor, with the option of excluding specific dims from the statistics.
+
+    Useful when processing multi-agent data, to keep the agent dimensions independent.
+
+    Args:
+        input (Tensor): the input tensor to be standardized.
+        exclude_dims (Tuple[int, ...]): dimensions to exclude from the statistics; entries can be negative. Default: ().
+        mean (Tensor): a mean to use for standardization. Must be broadcastable to the shape of input. Default: None.
+        std (Tensor): a standard deviation to use for standardization. Must be broadcastable to the shape of input. Default: None.
+        eps (float): epsilon used for numerical stability. Default: float32 resolution.
+
+    """
+    if eps is None:
+        # Default to the float32 resolution for floating-point inputs, a fixed epsilon otherwise.
+        if input.dtype.is_floating_point:
+            eps = torch.finfo(torch.float).resolution
+        else:
+            eps = 1e-6
+
+    len_exclude_dims = len(exclude_dims)
+    if not len_exclude_dims:
+        # Fast path: statistics are computed over the whole tensor.
+        if mean is None:
+            mean = input.mean()
+        else:
+            # Assume dtypes are compatible
+            mean = torch.as_tensor(mean, device=input.device)
+        if std is None:
+            std = input.std()
+        else:
+            # Assume dtypes are compatible
+            std = torch.as_tensor(std, device=input.device)
+        return (input - mean) / std.clamp_min(eps)
+
+    input_shape = input.shape
+    # Make negative dims positive
+    exclude_dims = [d if d >= 0 else d + len(input_shape) for d in exclude_dims]
+
+    if len(set(exclude_dims)) != len_exclude_dims:
+        raise ValueError("exclude_dims has repeating elements")
+    if any(dim < 0 or dim >= len(input_shape) for dim in exclude_dims):
+        raise ValueError(
+            f"exclude_dims={exclude_dims} provided outside bounds for input of shape={input_shape}"
+        )
+    if len_exclude_dims == len(input_shape):
+        warnings.warn(
+            "_standardize called but all dims were excluded from the statistics, returning unprocessed input"
+        )
+        return input
+
+    # Reduce over every dim that is not excluded, keeping dims so the
+    # statistics broadcast back against the input.
+    included_dims = tuple(d for d in range(len(input_shape)) if d not in exclude_dims)
+    if mean is None:
+        mean = torch.mean(input, keepdim=True, dim=included_dims)
+    if std is None:
+        std = torch.std(input, keepdim=True, dim=included_dims)
+    return (input - mean) / std.clamp_min(eps)
+
+
 @wraps(torch.compile)
 def compile_with_warmup(*args, warmup: int = 1, **kwargs):
     """Compile a model with warm-up.
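
For reference, a minimal usage sketch of the new helper (illustrative only, not part of the diff; the `torchrl._utils` import path and the tensor shapes are assumptions). Excluding the agent dim from the statistics gives each agent its own mean/std, computed over the remaining dims with keepdim=True:

import torch
from torchrl._utils import _standardize  # assumed location of this helper

val = torch.randn(32, 3, 8)  # hypothetical [batch, n_agents, feature] data

# Global statistics: a single scalar mean/std over the whole tensor.
out_global = _standardize(val)

# Exclude the agent dim (1): statistics are reduced over dims (0, 2),
# yielding (1, 3, 1)-shaped mean/std, so each agent is standardized
# independently.
out_per_agent = _standardize(val, exclude_dims=(1,))
assert out_per_agent.shape == val.shape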