Skip to content

Commit 86c416d

Browse files
authored
Add enable_frame_num='sequence' mode to video readers. (#6237)
- Extends the `enable_frame_num` argument in both the legacy (`fn.readers.video`) and experimental (`fn.experimental.readers.video`) video reader operators from a boolean to a string enum, following the same convention as `out_of_bounds_policy`:
  * ``"none"``/`False` (default) - no frame number output (previous `False`)
  * ``"scalar"``/`True` - returns the index of the first frame in the decoded sequence as a scalar output with shape `(1,)` (previous `True`)
  * ``"sequence"`` - returns the frame index of each decoded frame as an additional output with shape `(F,)`; padded frames get index `-1`
- The `FrameNumPolicy` enum and `ParseFrameNumPolicy` helper are added to `video_utils.h` and shared by both readers.
- Tests are added for the `sequence` mode covering basic stride behavior, constant-padding (``-1`` sentinel), and consistency between `"scalar"` and `"sequence"` outputs.

Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent ff8655b commit 86c416d

File tree

8 files changed

+418
-35
lines changed

8 files changed

+418
-35
lines changed

dali/operators/video/legacy/reader/nvdecoder/sequencewrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct SequenceWrapper {
4848

4949
timestamps.clear();
5050
timestamps.reserve(max_count);
51+
frame_idxs.clear();
5152

5253
if (!event_) {
5354
event_ = CUDAEvent::CreateWithFlags(cudaEventBlockingSync | cudaEventDisableTiming);
@@ -83,6 +84,7 @@ struct SequenceWrapper {
8384
int channels = -1;
8485
int label = -1;
8586
vector<double> timestamps;
87+
vector<int> frame_idxs;
8688
int first_frame_idx = -1;
8789
DALIDataType dtype = DALI_NO_TYPE;
8890
std::function<void(void)> read_sample_f;

dali/operators/video/legacy/reader/video_reader_op.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,14 @@ sequence and a warning. This option is mutually exclusive with `filenames`
108108
and `file_root`.)code",
109109
std::string())
110110
.AddOptionalArg("enable_frame_num",
111-
R"code(If the `file_list` or `filenames` argument is passed, returns the frame number
112-
output.)code",
113-
false)
111+
R"code(Determines what frame number information is returned as an additional output.
112+
Only available when `file_list` or `filenames` with `labels` is passed.
113+
114+
* ``None`` or ``False`` (default): No frame number output.
115+
* ``"scalar"`` or ``True``: Returns the index of the first frame in the decoded sequence, shape ``(1,)``.
116+
* ``"sequence"``: Returns the frame index of each decoded frame, shape ``(F,)``. For padded
117+
frames, the index is ``-1``.)code",
118+
std::string("none"))
114119
.AddOptionalArg("enable_timestamps",
115120
R"code(If the `file_list` or `filenames` argument is passed, returns the timestamps
116121
output. )code",

dali/operators/video/legacy/reader/video_reader_op.h

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ inline int VideoReaderOutputFn(const OpSpec &spec) {
3030
std::vector<std::string> file_names = spec.GetRepeatedArgument<std::string>("filenames");
3131
std::vector<int> labels;
3232
bool has_labels_arg = spec.TryGetRepeatedArgument(labels, "labels");
33-
bool enable_frame_num = spec.GetArgument<bool>("enable_frame_num");
33+
FrameNumPolicy frame_num_policy =
34+
ParseFrameNumPolicy(spec.GetArgument<std::string>("enable_frame_num"));
3435
bool enable_timestamps = spec.GetArgument<bool>("enable_timestamps");
3536
int num_outputs = 1;
3637
if ((!file_names.empty() && has_labels_arg) || !file_root.empty() || !file_list.empty()) {
3738
++num_outputs;
3839
}
3940
if (!file_list.empty() || !file_names.empty()) {
40-
if (enable_frame_num) num_outputs++;
41+
if (frame_num_policy != FrameNumPolicy::None) num_outputs++;
4142
if (enable_timestamps) num_outputs++;
4243
}
4344
return num_outputs;
@@ -51,9 +52,10 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
5152
filenames_(spec.GetRepeatedArgument<std::string>("filenames")),
5253
file_root_(spec.GetArgument<std::string>("file_root")),
5354
file_list_(spec.GetArgument<std::string>("file_list")),
54-
enable_frame_num_(spec.GetArgument<bool>("enable_frame_num")),
55+
frame_num_policy_(ParseFrameNumPolicy(spec.GetArgument<std::string>("enable_frame_num"))),
5556
enable_timestamps_(spec.GetArgument<bool>("enable_timestamps")),
5657
count_(spec.GetArgument<int>("sequence_length")),
58+
stride_(spec.GetArgument<int>("stride")),
5759
channels_(spec.GetArgument<int>("channels")),
5860
dtype_(spec.GetArgument<DALIDataType>("dtype")) {
5961
DALIImageType image_type(spec.GetArgument<DALIImageType>("image_type"));
@@ -75,7 +77,7 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
7577

7678
can_use_frames_timestamps_ = !file_list_.empty() || (!filenames_.empty() && has_labels_arg);
7779

78-
DALI_ENFORCE(can_use_frames_timestamps_ || !enable_frame_num_,
80+
DALI_ENFORCE(can_use_frames_timestamps_ || frame_num_policy_ == FrameNumPolicy::None,
7981
"frame numbers can be enabled only when "
8082
"`file_list`, or `filenames` with `labels` argument are passed");
8183
DALI_ENFORCE(can_use_frames_timestamps_ || !enable_timestamps_,
@@ -99,7 +101,10 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
99101
label_shape_ = uniform_list_shape(max_batch_size_, {1});
100102

101103
if (can_use_frames_timestamps_) {
102-
if (enable_frame_num_) frame_num_shape_ = label_shape_;
104+
if (frame_num_policy_ == FrameNumPolicy::Scalar)
105+
frame_num_shape_ = label_shape_;
106+
else if (frame_num_policy_ == FrameNumPolicy::Sequence)
107+
frame_num_shape_ = uniform_list_shape(max_batch_size_, {count_});
103108
if (enable_timestamps_) timestamp_shape_ = uniform_list_shape(max_batch_size_, {count_});
104109
}
105110

@@ -134,7 +139,7 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
134139
label_output_ = &ws.Output<GPUBackend>(output_index++);
135140
label_output_->Resize(label_shape_, DALI_INT32);
136141
if (can_use_frames_timestamps_) {
137-
if (enable_frame_num_) {
142+
if (frame_num_policy_ != FrameNumPolicy::None) {
138143
frame_num_output_ = &ws.Output<GPUBackend>(output_index++);
139144
frame_num_output_->Resize(frame_num_shape_, DALI_INT32);
140145
}
@@ -163,10 +168,28 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
163168
CUDA_CALL(
164169
cudaMemcpyAsync(label, &prefetched_video.label, sizeof(int), cudaMemcpyDefault, stream));
165170
if (can_use_frames_timestamps_) {
166-
if (enable_frame_num_) {
171+
if (frame_num_policy_ == FrameNumPolicy::Scalar) {
167172
auto *frame_num = frame_num_output_->mutable_tensor<int>(data_idx);
168173
CUDA_CALL(cudaMemcpyAsync(frame_num, &prefetched_video.first_frame_idx, sizeof(int),
169174
cudaMemcpyDefault, stream));
175+
} else if (frame_num_policy_ == FrameNumPolicy::Sequence) {
176+
// Compute per-frame frame indices from first_frame_idx and stride.
177+
// Frames beyond the actual decoded count (padded frames) get index -1.
178+
auto &idxs = prefetched_video.frame_idxs;
179+
idxs.resize(count_);
180+
for (int i = 0; i < count_; ++i) {
181+
idxs[i] = (i < prefetched_video.count)
182+
? (prefetched_video.first_frame_idx + i * stride_)
183+
: -1;
184+
}
185+
auto *frame_num_data = frame_num_output_->mutable_tensor<int>(data_idx);
186+
frame_num_output_->type_info().Copy<GPUBackend, CPUBackend>(
187+
frame_num_data,
188+
std::nullopt,
189+
idxs.data(),
190+
std::nullopt,
191+
idxs.size(),
192+
stream);
170193
}
171194
if (enable_timestamps_) {
172195
auto *timestamp = timestamp_output_->mutable_tensor<double>(data_idx);
@@ -212,9 +235,10 @@ class VideoReader : public DataReader<GPUBackend, SequenceWrapper, SequenceWrapp
212235
std::vector<int> labels_;
213236
std::string file_root_;
214237
std::string file_list_;
215-
bool enable_frame_num_;
238+
FrameNumPolicy frame_num_policy_;
216239
bool enable_timestamps_;
217240
int count_;
241+
int stride_;
218242
int channels_;
219243

220244
TensorListShape<> label_shape_;

dali/operators/video/legacy/reader/video_reader_op_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ TEST_F(VIDEO_READER_TEST_CLASS, FrameLabels) {
469469
.AddArg("device", "gpu")
470470
.AddArg("random_shuffle", false)
471471
.AddArg("sequence_length", sequence_length)
472-
.AddArg("enable_frame_num", true)
472+
.AddArg("enable_frame_num", "scalar")
473473
.AddArg("image_type", DALI_YCbCr)
474474
.AddArg("file_list", file_list_path)
475475
.AddOutput("frames", StorageDevice::GPU)
@@ -510,7 +510,7 @@ TEST_F(VIDEO_READER_TEST_CLASS, FrameLabelsFilenames) {
510510
.AddArg("device", "gpu")
511511
.AddArg("random_shuffle", false)
512512
.AddArg("sequence_length", sequence_length)
513-
.AddArg("enable_frame_num", true)
513+
.AddArg("enable_frame_num", "scalar")
514514
.AddArg("image_type", DALI_YCbCr)
515515
.AddArg("filenames", std::vector<std::string>{testing::dali_extra_path() +
516516
"/db/video/frame_num_timestamp/test.mp4"})
@@ -558,7 +558,7 @@ TEST_F(VIDEO_READER_TEST_CLASS, LabelsFilenames) {
558558
.AddArg("device", "gpu")
559559
.AddArg("random_shuffle", false)
560560
.AddArg("sequence_length", sequence_length)
561-
.AddArg("enable_frame_num", true)
561+
.AddArg("enable_frame_num", "scalar")
562562
.AddArg("image_type", DALI_YCbCr)
563563
.AddArg("filenames", std::vector<std::string>{testing::dali_extra_path() +
564564
"/db/video/frame_num_timestamp/test.mp4"})
@@ -621,7 +621,7 @@ TEST_F(VIDEO_READER_TEST_CLASS, FrameLabelsWithFileListFrameNum) {
621621
.AddArg("device", "gpu")
622622
.AddArg("random_shuffle", false)
623623
.AddArg("sequence_length", sequence_length)
624-
.AddArg("enable_frame_num", true)
624+
.AddArg("enable_frame_num", "scalar")
625625
.AddArg("enable_timestamps", true)
626626
.AddArg("file_list_frame_num", true)
627627
.AddArg("file_list_format", "frames") // equivalent to file_list_frame_num in the old decoder
@@ -702,7 +702,7 @@ TEST_F(VIDEO_READER_TEST_CLASS, TimestampLabels) {
702702
.AddArg("random_shuffle", false)
703703
.AddArg("sequence_length", sequence_length)
704704
.AddArg("enable_timestamps", true)
705-
.AddArg("enable_frame_num", true)
705+
.AddArg("enable_frame_num", "scalar")
706706
.AddArg("image_type", DALI_YCbCr)
707707
.AddArg("file_list", file_list_path)
708708
.AddOutput("frames", StorageDevice::GPU)

dali/operators/video/reader/video_reader_decoder_op.cc

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct VideoSample : public VideoSampleDesc {
9898
// to be filled by Prefetch
9999
Tensor<Backend> data_;
100100
std::vector<double> timestamps_;
101-
std::vector<int64_t> frame_idx_;
101+
std::vector<int32_t> frame_idx_;
102102
};
103103

104104
enum class FileListFormat {
@@ -394,7 +394,7 @@ class VideoReaderDecoder
394394

395395
explicit VideoReaderDecoder(const OpSpec &spec)
396396
: Base(spec),
397-
has_frame_idx_(spec.GetArgument<bool>("enable_frame_num")),
397+
frame_num_policy_(ParseFrameNumPolicy(spec.GetArgument<std::string>("enable_frame_num"))),
398398
has_timestamps_(spec.GetArgument<bool>("enable_timestamps")),
399399
boundary_type_(GetBoundaryType(spec)),
400400
image_type_(spec.GetArgument<DALIImageType>("image_type")) {
@@ -465,9 +465,16 @@ class VideoReaderDecoder
465465
output_desc.push_back({label_shape, DALI_INT32});
466466
}
467467

468-
if (has_frame_idx_) {
468+
if (frame_num_policy_ == FrameNumPolicy::Scalar) {
469469
TensorListShape<1> frame_idx_shape = uniform_list_shape<1>(batch_size, {1});
470470
output_desc.push_back({frame_idx_shape, DALI_INT32});
471+
} else if (frame_num_policy_ == FrameNumPolicy::Sequence) {
472+
TensorListShape<1> frame_idx_shape(batch_size);
473+
for (int sample_id = 0; sample_id < batch_size; ++sample_id) {
474+
auto num_frames = GetSample(sample_id).data_.shape()[0];
475+
frame_idx_shape.set_tensor_shape(sample_id, {num_frames});
476+
}
477+
output_desc.push_back({frame_idx_shape, DALI_INT32});
471478
}
472479

473480
if (has_timestamps_) {
@@ -526,10 +533,14 @@ class VideoReaderDecoder
526533
return make_cspan(&s.video_file_meta_->label, 1);
527534
});
528535
}
529-
if (has_frame_idx_) {
536+
if (frame_num_policy_ == FrameNumPolicy::Scalar) {
530537
OutputMetadata<int32_t>(ws, out_index++, [](auto &s) {
531538
return make_cspan(&s.start_, 1);
532539
});
540+
} else if (frame_num_policy_ == FrameNumPolicy::Sequence) {
541+
OutputMetadata<int32_t>(ws, out_index++, [](auto &s) {
542+
return make_cspan(s.frame_idx_);
543+
});
533544
}
534545
if (has_timestamps_) {
535546
OutputMetadata<double>(ws, out_index++, [](auto &s) {
@@ -601,6 +612,17 @@ class VideoReaderDecoder
601612
<< ", boundary_type=" << to_string(boundary_type_) << std::endl;
602613
int roi_start = sample->video_file_meta_->start_frame;
603614
int roi_end = sample->video_file_meta_->end_frame;
615+
if (frame_num_policy_ == FrameNumPolicy::Sequence) {
616+
sample->frame_idx_.resize(num_frames);
617+
for (int64_t i = 0; i < num_frames; ++i) {
618+
sample->frame_idx_[i] = static_cast<int32_t>(decoder_->HandleBoundary(
619+
boundary_type_,
620+
static_cast<int>(sample->start_ + i * sample->stride_),
621+
roi_start, roi_end));
622+
}
623+
} else {
624+
sample->frame_idx_.clear();
625+
}
604626
if (roi_start != 0 || roi_end != decoder_->NumFrames()) {
605627
frame_idxs_.clear();
606628
for (int frame_idx = sample->start_; frame_idx < sample->end_;
@@ -626,7 +648,7 @@ class VideoReaderDecoder
626648
}
627649

628650
private:
629-
bool has_frame_idx_;
651+
FrameNumPolicy frame_num_policy_;
630652
bool has_timestamps_;
631653
boundary::BoundaryType boundary_type_;
632654
DALIImageType image_type_;
@@ -658,22 +680,28 @@ The following codecs are supported by the GPU backend only:
658680
* AV1
659681
* MPEG-4
660682
661-
The outputs of the operator are: video, [labels], [frame_idx], [timestamp].
683+
The outputs of the operator are: video, [labels], [frame_num], [timestamps].
662684
663685
* ``video``: A sequence of frames with shape ``(F, H, W, C)`` where ``F`` is the number of frames in the sequence
664686
(can vary between samples), ``H`` is the frame height in pixels, ``W`` is the frame width in pixels, and ``C`` is
665687
the number of color channels.
666688
* ``labels``: Label associated with the sample. Only available when using ``labels`` with ``filenames``, or when
667689
using ``file_list`` or ``file_root``.
668-
* ``frame_idx``: Index of first frame in sequence. Only available when ``enable_frame_num=True``.
690+
* ``frame_num``: Frame number information. Shape and content depend on ``enable_frame_num``:
691+
692+
* ``"scalar"`` or ``True``: Index of the first frame in the decoded sequence, shape ``(1,)``.
693+
* ``"sequence"``: Frame index of each decoded frame, shape ``(F,)``. Padded frames (e.g. when
694+
using ``pad_mode='constant'``) have index ``-1``.
669695
* ``timestamps``: Time in seconds of each frame in the sequence. Only available when ``enable_timestamps=True``.
670696
)code")
671697
.NumInput(0)
672698
.OutputFn([](const OpSpec &spec) {
673699
bool has_labels = spec.HasArgument("labels") || spec.HasArgument("file_list") ||
674700
spec.HasArgument("file_root");
675-
return 1 + has_labels + spec.GetArgument<bool>("enable_frame_num") +
676-
spec.GetArgument<bool>("enable_timestamps");
701+
bool has_frame_num =
702+
ParseFrameNumPolicy(spec.GetArgument<std::string>("enable_frame_num")) !=
703+
FrameNumPolicy::None;
704+
return 1 + has_labels + has_frame_num + spec.GetArgument<bool>("enable_timestamps");
677705
})
678706
.AddOptionalArg("filenames",
679707
R"code(Absolute paths to the video files to load.
@@ -705,7 +733,7 @@ Default: ``timestamps``.)code",
705733
R"code(How to handle non-exact frame matches:
706734
707735
* ``start_down_end_up`` (default): Round start down and end up
708-
* ``start_up_end_down``: Round start up and end down
736+
* ``start_up_end_down``: Round start up and end down
709737
* ``all_up``: Round both up
710738
* ``all_down``: Round both down)code",
711739
"start_down_end_up")
@@ -717,9 +745,13 @@ Default: ``timestamps``.)code",
717745
nullptr)
718746
.AddArg("sequence_length", R"code(Frames to load per sequence.)code", DALI_INT32)
719747
.AddOptionalArg("enable_frame_num",
720-
R"code(If set, returns the index of the first frame in the decoded sequence
721-
as an additional output.)code",
722-
false)
748+
R"code(Determines what frame number information is returned as an additional output.
749+
750+
* ``"none"`` or ``False`` (default): No frame number output.
751+
* ``"scalar"`` or ``True``: Returns the index of the first frame in the decoded sequence, shape ``(1,)``.
752+
* ``"sequence"``: Returns the frame index of each decoded frame, shape ``(F,)``. For padded
753+
frames (e.g. when using ``pad_mode='constant'``), the index is ``-1``.)code",
754+
std::string("none"))
723755
.AddOptionalArg("enable_timestamps",
724756
R"code(If set, returns the timestamp of the frames in the decoded sequence
725757
as an additional output.)code",
@@ -736,7 +768,7 @@ When the value is less than 0, `step` is set to `sequence_length`.)code",
736768
R"code(How to handle videos with insufficient frames when using start_frame/sequence_length/stride:
737769
738770
* ``'none'``: Return shorter sequences if not enough frames: ABC -> ABC
739-
* ``'constant'``: Pad with a fixed value (specified by ``pad_value``): ABC -> ABCPPP
771+
* ``'constant'``: Pad with a fixed value (specified by ``pad_value``): ABC -> ABCPPP
740772
* ``'edge'`` or ``'repeat'``: Repeat the last valid frame: ABC -> ABCCCC
741773
* ``'reflect_1001'`` or ``'symmetric'``: Reflect padding, including the last element: ABC -> ABCCBA
742774
* ``'reflect_101'`` or ``'reflect'``: Reflect padding, not including the last element: ABC -> ABCBA
@@ -747,7 +779,7 @@ Not relevant when using ``frames`` argument.)code",
747779
R"code(Value(s) used to pad missing frames when ``pad_mode='constant'``'.
748780
749781
Each value must be in range [0, 255].
750-
If a single value is provided, it will be used for all channels.
782+
If a single value is provided, it will be used for all channels.
751783
Otherwise, the number of values must match the number of channels in the video.)code",
752784
std::vector<int>{
753785
0,

dali/operators/video/reader/video_reader_decoder_op_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class VideoReaderDecoderBaseTest : public VideoTestBase {
137137
.AddArg("device", backend)
138138
.AddArg("sequence_length", sequence_length)
139139
.AddArg("random_shuffle", true)
140-
.AddArg("enable_frame_num", true)
140+
.AddArg("enable_frame_num", "scalar")
141141
.AddArg("initial_fill", cfr_videos_[0].NumFrames())
142142
.AddArg(
143143
"filenames",

dali/operators/video/video_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ std::vector<VideoFileMeta> GetVideoFiles(const std::string& file_root,
5959
const std::vector<int>& labels,
6060
const std::string& file_list);
6161

62+
enum class FrameNumPolicy {
63+
None, // no frame number output
64+
Scalar, // first frame index as a scalar with shape (1,)
65+
Sequence // per-frame indices with shape (F,); padded frames get -1
66+
};
67+
68+
inline FrameNumPolicy ParseFrameNumPolicy(const std::string &s) {
69+
// "True"/"False" are the Python str(bool) representations, kept for backward compatibility
70+
// with code that passes enable_frame_num=True/False (Python bools).
71+
if (s == "none" || s == "False") return FrameNumPolicy::None;
72+
if (s == "scalar" || s == "True") return FrameNumPolicy::Scalar;
73+
if (s == "sequence") return FrameNumPolicy::Sequence;
74+
DALI_FAIL(make_string("Invalid enable_frame_num value: '", s,
75+
"'. Valid values are: 'none', 'scalar', 'sequence'."));
76+
}
77+
6278
inline boundary::BoundaryType GetBoundaryType(const OpSpec &spec) {
6379
auto pad_mode_str = spec.template GetArgument<std::string>("pad_mode");
6480
boundary::BoundaryType boundary_type = boundary::BoundaryType::ISOLATED;

0 commit comments

Comments
 (0)