|
19 | 19 | # :ref:`creating_decoder`. |
20 | 20 |
|
21 | 21 | from typing import Optional |
22 | | -import torch |
| 22 | + |
23 | 23 | import requests |
| 24 | +import torch |
24 | 25 |
|
25 | 26 |
|
26 | 27 | # Video source: https://www.pexels.com/video/dog-eating-854132/ |
|
33 | 34 | raw_video_bytes = response.content |
34 | 35 |
|
35 | 36 |
|
36 | | -def plot(frames: torch.Tensor, title : Optional[str] = None): |
| 37 | +def plot(frames: torch.Tensor, title: Optional[str] = None): |
37 | 38 | try: |
38 | | - from torchvision.utils import make_grid |
39 | | - from torchvision.transforms.v2.functional import to_pil_image |
40 | 39 | import matplotlib.pyplot as plt |
| 40 | + from torchvision.transforms.v2.functional import to_pil_image |
| 41 | + from torchvision.utils import make_grid |
41 | 42 | except ImportError: |
42 | 43 | print("Cannot plot, please run `pip install torchvision matplotlib`") |
43 | 44 | return |
44 | 45 |
|
45 | | - plt.rcParams["savefig.bbox"] = 'tight' |
| 46 | + plt.rcParams["savefig.bbox"] = "tight" |
46 | 47 | fig, ax = plt.subplots() |
47 | 48 | ax.imshow(to_pil_image(make_grid(frames))) |
48 | 49 | ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) |
@@ -76,7 +77,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): |
76 | 77 | # --------------------------------------- |
77 | 78 |
|
78 | 79 | first_frame = decoder[0] # using a single int index |
79 | | -every_twenty_frame = decoder[0 : -1 : 20] # using slices |
| 80 | +every_twenty_frame = decoder[0:-1:20] # using slices |
80 | 81 |
|
81 | 82 | print(f"{first_frame.shape = }") |
82 | 83 | print(f"{first_frame.dtype = }") |
@@ -106,9 +107,10 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): |
106 | 107 | # The decoder is a normal iterable object and can be iterated over like so: |
107 | 108 |
|
108 | 109 | for frame in decoder: |
109 | | - assert ( |
110 | | - isinstance(frame, torch.Tensor) |
111 | | - and frame.shape == (3, decoder.metadata.height, decoder.metadata.width) |
| 110 | + assert isinstance(frame, torch.Tensor) and frame.shape == ( |
| 111 | + 3, |
| 112 | + decoder.metadata.height, |
| 113 | + decoder.metadata.width, |
112 | 114 | ) |
113 | 115 |
|
114 | 116 | # %% |
|
0 commit comments