|
36 | 36 | from torchvision.transforms.functional import center_crop |
37 | 37 |
|
38 | 38 | try: |
39 | | - from torchvision.transforms.functional import resize |
| 39 | + from torchvision.transforms.functional import InterpolationMode, resize |
| 40 | + |
| 41 | + def interpolation_fn(interpolation): # noqa: D103 |
| 42 | + return InterpolationMode(interpolation) |
| 43 | + |
40 | 44 | except ImportError: |
| 45 | + |
| 46 | + def interpolation_fn(interpolation): # noqa: D103 |
| 47 | + return interpolation |
| 48 | + |
41 | 49 | from torchvision.transforms.functional_tensor import resize |
42 | 50 |
|
43 | 51 | _has_tv = True |
@@ -65,6 +73,14 @@ def new_fun(self, observation_spec): |
65 | 73 |
|
66 | 74 |
|
67 | 75 | def _apply_to_composite_inv(function): |
| 76 | + # Changes the input_spec following a transform function. |
| 77 | + # The usage is: if an env expects a certain input (e.g. a double tensor) |
| 78 | + # but the input has to be transformed (e.g. it is float), this function will |
| 79 | + # modify the spec to get a spec that from the outside matches what is given |
| 80 | + # (ie a float). |
| 81 | + # Now since EnvBase.step ignores new inputs (ie the root level of the |
| 82 | + # tensor is not updated) an out_key that does not match the in_key has |
| 83 | + # no effect on the spec. |
68 | 84 | def new_fun(self, input_spec): |
69 | 85 | if isinstance(input_spec, CompositeSpec): |
70 | 86 | d = input_spec._specs |
@@ -996,7 +1012,7 @@ def __init__( |
996 | 1012 | super().__init__(in_keys=in_keys, out_keys=out_keys) |
997 | 1013 | self.w = int(w) |
998 | 1014 | self.h = int(h) |
999 | | - self.interpolation = interpolation |
| 1015 | + self.interpolation = interpolation_fn(interpolation) |
1000 | 1016 |
|
1001 | 1017 | def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: |
1002 | 1018 | # flatten if necessary |
|
0 commit comments