Skip to content

Commit 0d4ab14

Browse files
authored
Prepare 0.0.2: revert GPU-related PRs on release branch (#169)
1 parent 430d258 commit 0d4ab14

File tree

9 files changed

+64
-170
lines changed

9 files changed

+64
-170
lines changed

benchmarks/decoders/BenchmarkDecodersMain.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ void runNDecodeIterations(
6363
decoder->addVideoStreamDecoder(-1);
6464
for (double pts : ptsList) {
6565
decoder->setCursorPtsInSeconds(pts);
66-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
66+
torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
6767
}
6868
if (i + 1 == warmupIterations) {
6969
start = std::chrono::high_resolution_clock::now();
@@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames(
9595
VideoDecoder::createFromFilePath(videoPath);
9696
decoder->addVideoStreamDecoder(-1);
9797
for (int j = 0; j < consecutiveFrameCount; ++j) {
98-
torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
98+
torch::Tensor tensor = decoder->getNextDecodedOutput().frame;
9999
}
100100
if (i + 1 == warmupIterations) {
101101
start = std::chrono::high_resolution_clock::now();

benchmarks/decoders/gpu_benchmark.py

Lines changed: 27 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -5,59 +5,43 @@
55
import torch.utils.benchmark as benchmark
66

77
import torchcodec
8-
import torchvision.transforms.v2.functional as F
8+
from torchvision.transforms import Resize
99

10-
RESIZED_WIDTH = 256
11-
RESIZED_HEIGHT = 256
1210

13-
14-
def transfer_and_resize_frame(frame, resize_device_string):
15-
# This should be a no-op if the frame is already on the target device.
16-
frame = frame.to(resize_device_string)
17-
frame = F.resize(frame, (RESIZED_HEIGHT, RESIZED_WIDTH))
11+
def transfer_and_resize_frame(frame, device):
12+
# This should be a no-op if the frame is already on the device.
13+
frame = frame.to(device)
14+
frame = Resize((256, 256))(frame)
1815
return frame
1916

2017

21-
def decode_full_video(video_path, decode_device_string, resize_device_string):
22-
# We use the core API instead of SimpleVideoDecoder because the core API
23-
# allows us to natively resize as part of the decode step.
24-
print(f"{decode_device_string=} {resize_device_string=}")
18+
def decode_full_video(video_path, decode_device):
2519
decoder = torchcodec.decoders._core.create_from_file(video_path)
2620
num_threads = None
27-
if "cuda" in decode_device_string:
21+
if "cuda" in decode_device:
2822
num_threads = 1
29-
width = None
30-
height = None
31-
if "native" in resize_device_string:
32-
width = RESIZED_WIDTH
33-
height = RESIZED_HEIGHT
3423
torchcodec.decoders._core.add_video_stream(
35-
decoder,
36-
stream_index=-1,
37-
device_string=decode_device_string,
38-
num_threads=num_threads,
39-
width=width,
40-
height=height,
24+
decoder, stream_index=0, device_string=decode_device, num_threads=num_threads
4125
)
42-
4326
start_time = time.time()
4427
frame_count = 0
4528
while True:
4629
try:
4730
frame, *_ = torchcodec.decoders._core.get_next_frame(decoder)
48-
if resize_device_string != "none" and "native" not in resize_device_string:
49-
frame = transfer_and_resize_frame(frame, resize_device_string)
31+
# You can do a resize to simulate extra preproc work that happens
32+
# on the GPU by uncommenting the following line:
33+
# frame = transfer_and_resize_frame(frame, decode_device)
5034

5135
frame_count += 1
5236
except Exception as e:
5337
print("EXCEPTION", e)
5438
break
55-
39+
# print(f"current {frame_count=}", flush=True)
5640
end_time = time.time()
5741
elapsed = end_time - start_time
5842
fps = frame_count / (end_time - start_time)
5943
print(
60-
f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}"
44+
f"****** DECODED full video {decode_device=} {frame_count=} {elapsed=} {fps=}"
6145
)
6246
return frame_count, end_time - start_time
6347

@@ -70,12 +54,6 @@ def main():
7054
type=str,
7155
help="Comma-separated devices to test decoding on.",
7256
)
73-
parser.add_argument(
74-
"--resize_devices",
75-
default="cuda:0,cpu,native,none",
76-
type=str,
77-
help="Comma-separated devices to test preroc (resize) on. Use 'none' to specify no resize.",
78-
)
7957
parser.add_argument(
8058
"--video",
8159
type=str,
@@ -100,44 +78,23 @@ def main():
10078
decode_full_video(video_path, device)
10179
return
10280

103-
resize_devices = args.resize_devices.split(",")
104-
resize_devices = [d for d in resize_devices if d != ""]
105-
if len(resize_devices) == 0:
106-
resize_devices.append("none")
107-
108-
label = "Decode+Resize Time"
109-
11081
results = []
111-
for decode_device_string in args.devices.split(","):
112-
for resize_device_string in resize_devices:
113-
decode_label = decode_device_string
114-
if "cuda" in decode_label:
115-
# Shorten "cuda:0" to "cuda"
116-
decode_label = "cuda"
117-
resize_label = resize_device_string
118-
if "cuda" in resize_device_string:
119-
# Shorten "cuda:0" to "cuda"
120-
resize_label = "cuda"
121-
print("decode_device", decode_device_string)
122-
print("resize_device", resize_device_string)
123-
t = benchmark.Timer(
124-
stmt="decode_full_video(video_path, decode_device_string, resize_device_string)",
125-
globals={
126-
"decode_device_string": decode_device_string,
127-
"video_path": video_path,
128-
"decode_full_video": decode_full_video,
129-
"resize_device_string": resize_device_string,
130-
},
131-
label=label,
132-
description=f"video={os.path.basename(video_path)}",
133-
sub_label=f"D={decode_label} R={resize_label}",
134-
).blocked_autorange()
135-
results.append(t)
82+
for device in args.devices.split(","):
83+
print("device", device)
84+
t = benchmark.Timer(
85+
stmt="decode_full_video(video_path, device)",
86+
globals={
87+
"device": device,
88+
"video_path": video_path,
89+
"decode_full_video": decode_full_video,
90+
},
91+
label="Decode+Resize Time",
92+
sub_label=f"video={os.path.basename(video_path)}",
93+
description=f"decode_device={device}",
94+
).blocked_autorange()
95+
results.append(t)
13696
compare = benchmark.Compare(results)
13797
compare.print()
138-
print("Key: D=Decode, R=Resize")
139-
print("Native resize is done as part of the decode step")
140-
print("none resize means there is no resize step -- native or otherwise")
14198

14299

143100
if __name__ == "__main__":

examples/basic_example.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,3 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
171171
# %%
172172
plot(frame_at_2_seconds.data, "Frame displayed at 2 seconds")
173173
plot(first_two_seconds.data, "Frames displayed during [0, 2) seconds")
174-
175-
# %%
176-
# Using a CUDA GPU to accelerate decoding
177-
# ---------------------------------------
178-
#
179-
# If you have a CUDA GPU that has NVDEC, you can decode on the GPU.
180-
if torch.cuda.is_available():
181-
cuda_decoder = SimpleVideoDecoder(raw_video_bytes, device="cuda:0")
182-
cuda_frame = cuda_decoder.get_frame_displayed_at(seconds=2)
183-
print(cuda_frame.data.device) # should be cuda:0
184-
plot(cuda_frame.data.to("cpu"), "Frame displayed at 2 seconds on CUDA")

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter(
759759
if (activeStreamIndices_.size() == 0) {
760760
throw std::runtime_error("No active streams configured.");
761761
}
762-
VLOG(9) << "Starting getNextDecodedOutputNoDemux()";
762+
VLOG(9) << "Starting getNextDecodedOutput()";
763763
resetDecodeStats();
764764
if (maybeDesiredPts_.has_value()) {
765765
VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_;
@@ -920,7 +920,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
920920
return output;
921921
}
922922

923-
VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestampNoDemux(
923+
VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
924924
double seconds) {
925925
for (auto& [streamIndex, stream] : streams_) {
926926
double frameStartTime = ptsToSeconds(stream.currentPts, stream.timeBase);
@@ -985,7 +985,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
985985

986986
int64_t pts = stream.allFrames[frameIndex].pts;
987987
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
988-
return getNextDecodedOutputNoDemux();
988+
return getNextDecodedOutput();
989989
}
990990

991991
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
@@ -1138,7 +1138,7 @@ VideoDecoder::getFramesDisplayedByTimestampInRange(
11381138
return output;
11391139
}
11401140

1141-
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
1141+
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() {
11421142
return getDecodedOutputWithFilter(
11431143
[this](int frameStreamIndex, AVFrame* frame) {
11441144
StreamInfo& activeStream = streams_[frameStreamIndex];

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ class VideoDecoder {
153153

154154
// ---- SINGLE FRAME SEEK AND DECODING API ----
155155
// Places the cursor at the first frame on or after the position in seconds.
156-
// Calling getNextDecodedOutputNoDemux() will return the first frame at or
157-
// after this position.
156+
// Calling getNextFrameAsTensor() will return the first frame at or after this
157+
// position.
158158
void setCursorPtsInSeconds(double seconds);
159159
struct DecodedOutput {
160160
// The actual decoded output as a Tensor.
@@ -180,14 +180,13 @@ class VideoDecoder {
180180
};
181181
// Decodes the frame where the current cursor position is. It also advances
182182
// the cursor to the next frame.
183-
DecodedOutput getNextDecodedOutputNoDemux();
184-
// Decodes the first frame in any added stream that is visible at a given
185-
// timestamp. Frames in the video have a presentation timestamp and a
186-
// duration. For example, if a frame has presentation timestamp of 5.0s and a
187-
// duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
188-
// i.e. it will be returned when this function is called with seconds=5.0 or
189-
// seconds=5.999, etc.
190-
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
183+
DecodedOutput getNextDecodedOutput();
184+
// Decodes the frame that is visible at a given timestamp. Frames in the video
185+
// have a presentation timestamp and a duration. For example, if a frame has
186+
// presentation timestamp of 5.0s and a duration of 1.0s, it will be visible
187+
// in the timestamp range [5.0, 6.0). i.e. it will be returned when this
188+
// function is called with seconds=5.0 or seconds=5.999, etc.
189+
DecodedOutput getFrameDisplayedAtTimestamp(double seconds);
191190
DecodedOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);
192191
struct BatchDecodedOutput {
193192
torch::Tensor frames;

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
147147
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
148148
VideoDecoder::DecodedOutput result;
149149
try {
150-
result = videoDecoder->getNextDecodedOutputNoDemux();
150+
result = videoDecoder->getNextDecodedOutput();
151151
} catch (const VideoDecoder::EndOfFileException& e) {
152152
throw pybind11::stop_iteration(e.what());
153153
}
@@ -161,7 +161,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
161161

162162
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
163163
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
164-
auto result = videoDecoder->getFrameDisplayedAtTimestampNoDemux(seconds);
164+
auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds);
165165
return makeOpsDecodedOutput(result);
166166
}
167167

src/torchcodec/decoders/_simple_video_decoder.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pathlib import Path
1010
from typing import Iterable, Iterator, Literal, Tuple, Union
1111

12-
from torch import device as torch_device, Tensor
12+
from torch import Tensor
1313

1414
from torchcodec.decoders import _core as core
1515

@@ -89,14 +89,6 @@ class SimpleVideoDecoder:
8989
This can be either "NCHW" (default) or "NHWC", where N is the batch
9090
size, C is the number of channels, H is the height, and W is the
9191
width of the frames.
92-
device (torch.device, optional): The device to use for decoding.
93-
Currently we only support CPU and CUDA devices. If CUDA is used,
94-
we use NVDEC and CUDA to do decoding and color-conversion
95-
respectively. The resulting frame is left on the GPU for further
96-
processing.
97-
You can either pass in a string like "cpu" or "cuda:0" or a
98-
torch.device like torch.device("cuda:0").
99-
Default: ``torch.device("cpu")``.
10092
10193
.. note::
10294
@@ -114,7 +106,6 @@ def __init__(
114106
self,
115107
source: Union[str, Path, bytes, Tensor],
116108
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
117-
device: Union[str, torch_device] = torch_device("cpu"),
118109
):
119110
if isinstance(source, str):
120111
self._decoder = core.create_from_file(source)
@@ -138,20 +129,7 @@ def __init__(
138129
)
139130

140131
core.scan_all_streams_to_update_metadata(self._decoder)
141-
num_threads = None
142-
if isinstance(device, str):
143-
device = torch_device(device)
144-
if device.type == "cuda":
145-
# Using multiple CPU threads seems to slow down decoding on CUDA.
146-
# CUDA internally uses dedicated hardware to do decoding so we
147-
# don't need CPU software threads here.
148-
num_threads = 1
149-
core.add_video_stream(
150-
self._decoder,
151-
dimension_order=dimension_order,
152-
device_string=str(device),
153-
num_threads=num_threads,
154-
)
132+
core.add_video_stream(self._decoder, dimension_order=dimension_order)
155133

156134
self.metadata, self._stream_index = _get_and_validate_stream_metadata(
157135
self._decoder

0 commit comments

Comments
 (0)