|
9 | 9 | import hashlib
|
10 | 10 | import importlib.util
|
11 | 11 | import multiprocessing as mp
|
| 12 | +import time |
12 | 13 | import warnings
|
13 | 14 | import weakref
|
14 | 15 | from copy import copy
|
@@ -10823,3 +10824,112 @@ def _transform_observation_spec(
|
10823 | 10824 | )
|
10824 | 10825 | )
|
10825 | 10826 | return observation_spec
|
| 10827 | + |
| 10828 | + |
| 10829 | +class Timer(Transform): |
| 10830 | + """A transform that measures the time intervals between `inv` and `call` operations in an environment. |
| 10831 | +
|
| 10832 | + The `Timer` transform is used to track the time elapsed between the `inv` call and the `call`, |
| 10833 | + and between the `call` and the `inv` call. This is useful for performance monitoring and debugging |
| 10834 | + within an environment. The time is measured in seconds and stored as a tensor with the default |
| 10835 | + dtype from PyTorch. If the tensordict has a batch size (e.g., in batched environments), the time will be expended |
| 10836 | + to the size of the input tensordict. |
| 10837 | +
|
| 10838 | + Attributes: |
| 10839 | + out_keys: The keys of the output tensordict for the inverse transform. Defaults to |
| 10840 | + `out_keys = [f"{time_key}_step", f"{time_key}_policy"]`, where the first key represents |
| 10841 | + the time it takes to make a step in the environment, and the second key represents the |
| 10842 | + time it takes to execute the policy. |
| 10843 | + time_key: A prefix for the keys where the time intervals will be stored in the tensordict. |
| 10844 | + Defaults to `"time"`. |
| 10845 | +
|
| 10846 | + Examples: |
| 10847 | + >>> from torchrl.envs import Timer, GymEnv |
| 10848 | + >>> |
| 10849 | + >>> env = GymEnv("Pendulum-v1").append_transform(Timer()) |
| 10850 | + >>> r = env.rollout(10) |
| 10851 | + >>> print("time for policy", r["time_policy"]) |
| 10852 | + time for policy tensor([0.0000, 0.0882, 0.0004, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, 0.0002, |
| 10853 | + 0.0002]) |
| 10854 | + >>> print("time for step", r["time_step"]) |
| 10855 | + time for step tensor([9.5797e-04, 1.6289e-03, 9.7990e-05, 8.0824e-05, 9.0837e-05, 7.6056e-05, |
| 10856 | + 8.2016e-05, 7.6056e-05, 8.1062e-05, 7.7009e-05]) |
| 10857 | + """ |
| 10858 | + |
| 10859 | + def __init__(self, out_keys: Sequence[NestedKey] = None, time_key: str = "time"): |
| 10860 | + if out_keys is None: |
| 10861 | + out_keys = [f"{time_key}_step", f"{time_key}_policy"] |
| 10862 | + elif len(out_keys) != 2: |
| 10863 | + raise TypeError(f"Expected two out_keys. Got out_keys={out_keys}.") |
| 10864 | + super().__init__([], out_keys) |
| 10865 | + self.time_key = time_key |
| 10866 | + self.last_inv_time = None |
| 10867 | + self.last_call_time = None |
| 10868 | + |
| 10869 | + def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: |
| 10870 | + self.last_inv_time = time.time() |
| 10871 | + return tensordict |
| 10872 | + |
| 10873 | + def _maybe_expand_and_set(self, key, time_elapsed, tensordict): |
| 10874 | + if isinstance(key, tuple): |
| 10875 | + parent_td = tensordict.get(key[:-1]) |
| 10876 | + key = key[-1] |
| 10877 | + else: |
| 10878 | + parent_td = tensordict |
| 10879 | + batch_size = parent_td.batch_size |
| 10880 | + if batch_size: |
| 10881 | + # Get the parent shape |
| 10882 | + time_elapsed_expand = time_elapsed.expand(parent_td.batch_size) |
| 10883 | + else: |
| 10884 | + time_elapsed_expand = time_elapsed |
| 10885 | + parent_td.set(key, time_elapsed_expand) |
| 10886 | + |
| 10887 | + def _reset( |
| 10888 | + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase |
| 10889 | + ) -> TensorDictBase: |
| 10890 | + current_time = time.time() |
| 10891 | + if self.last_inv_time is not None: |
| 10892 | + time_elapsed = torch.tensor( |
| 10893 | + current_time - self.last_inv_time, device=tensordict.device |
| 10894 | + ) |
| 10895 | + self._maybe_expand_and_set(self.out_keys[0], time_elapsed, tensordict_reset) |
| 10896 | + self.last_call_time = current_time |
| 10897 | + # Placeholder |
| 10898 | + self._maybe_expand_and_set(self.out_keys[1], time_elapsed * 0, tensordict_reset) |
| 10899 | + return tensordict_reset |
| 10900 | + |
| 10901 | + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: |
| 10902 | + current_time = time.time() |
| 10903 | + if self.last_call_time is not None: |
| 10904 | + time_elapsed = torch.tensor( |
| 10905 | + current_time - self.last_call_time, device=tensordict.device |
| 10906 | + ) |
| 10907 | + self._maybe_expand_and_set(self.out_keys[1], time_elapsed, tensordict) |
| 10908 | + self.last_inv_time = current_time |
| 10909 | + return tensordict |
| 10910 | + |
| 10911 | + def _step( |
| 10912 | + self, tensordict: TensorDictBase, next_tensordict: TensorDictBase |
| 10913 | + ) -> TensorDictBase: |
| 10914 | + current_time = time.time() |
| 10915 | + if self.last_inv_time is not None: |
| 10916 | + time_elapsed = torch.tensor( |
| 10917 | + current_time - self.last_inv_time, device=tensordict.device |
| 10918 | + ) |
| 10919 | + self._maybe_expand_and_set(self.out_keys[0], time_elapsed, next_tensordict) |
| 10920 | + self.last_call_time = current_time |
| 10921 | + # presumbly no need to worry about batch size incongruencies here |
| 10922 | + next_tensordict.set(self.out_keys[1], tensordict.get(self.out_keys[1])) |
| 10923 | + return next_tensordict |
| 10924 | + |
| 10925 | + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: |
| 10926 | + observation_spec[self.out_keys[0]] = Unbounded( |
| 10927 | + shape=observation_spec.shape, device=observation_spec.device |
| 10928 | + ) |
| 10929 | + observation_spec[self.out_keys[1]] = Unbounded( |
| 10930 | + shape=observation_spec.shape, device=observation_spec.device |
| 10931 | + ) |
| 10932 | + return observation_spec |
| 10933 | + |
| 10934 | + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: |
| 10935 | + raise NotImplementedError(FORWARD_NOT_IMPLEMENTED) |
0 commit comments