@@ -1197,22 +1197,52 @@ def test_pts_to_dts_fallback(self, seek_mode):
1197
1197
torch .testing .assert_close (decoder [0 ], decoder [10 ])
1198
1198
1199
1199
@needs_cuda
1200
- @pytest .mark .parametrize ("asset" , (H264_10BITS , H265_10BITS ))
1201
- def test_10bit_videos_cuda (self , asset ):
1200
+ def test_10bit_videos_cuda (self ):
1202
1201
# Assert that we raise proper error on different kinds of 10bit videos.
1203
1202
1204
1203
# TODO we should investigate how to support 10bit videos on GPU.
1205
1204
# See https://github.com/pytorch/torchcodec/issues/776
1206
1205
1207
- decoder = VideoDecoder ( asset . path , device = "cuda" )
1206
+ asset = H265_10BITS
1208
1207
1209
- if asset is H265_10BITS :
1210
- match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12."
1211
- else :
1212
- match = "Expected format to be AV_PIX_FMT_CUDA, got yuv420p10le."
1213
- with pytest . raises ( RuntimeError , match = match ):
1208
+ decoder = VideoDecoder ( asset . path , device = "cuda" )
1209
+ with pytest . raises (
1210
+ RuntimeError ,
1211
+ match = "The AVFrame is p010le, but we expected AV_PIX_FMT_NV12." ,
1212
+ ):
1214
1213
decoder .get_frame_at (0 )
1215
1214
1215
+ @needs_cuda
1216
+ def test_10bit_gpu_fallsback_to_cpu (self ):
1217
+ # Test for 10-bit videos that aren't supported by NVDEC: we decode and
1218
+ # do the color conversion on the CPU.
1219
+ # Here we just assert that the GPU results are the same as the CPU
1220
+ # results.
1221
+ # TODO see other TODO below in test_10bit_videos_cpu: we should validate
1222
+ # the frames against a reference.
1223
+
1224
+ # We know from previous tests that the H264_10BITS video isn't supported
1225
+ # by NVDEC, so NVDEC decodes it on the CPU.
1226
+ asset = H264_10BITS
1227
+
1228
+ decoder_gpu = VideoDecoder (asset .path , device = "cuda" )
1229
+ decoder_cpu = VideoDecoder (asset .path )
1230
+
1231
+ frame_indices = [0 , 10 , 20 , 5 ]
1232
+ for frame_index in frame_indices :
1233
+ frame_gpu = decoder_gpu .get_frame_at (frame_index ).data
1234
+ assert frame_gpu .device .type == "cuda"
1235
+ frame_cpu = decoder_cpu .get_frame_at (frame_index ).data
1236
+ assert_frames_equal (frame_gpu .cpu (), frame_cpu )
1237
+
1238
+ # We also check a batch API just to be on the safe side, making sure the
1239
+ # pre-allocated tensor is passed down correctly to the CPU
1240
+ # implementation.
1241
+ frames_gpu = decoder_gpu .get_frames_at (frame_indices ).data
1242
+ assert frames_gpu .device .type == "cuda"
1243
+ frames_cpu = decoder_cpu .get_frames_at (frame_indices ).data
1244
+ assert_frames_equal (frames_gpu .cpu (), frames_cpu )
1245
+
1216
1246
@pytest .mark .parametrize ("asset" , (H264_10BITS , H265_10BITS ))
1217
1247
def test_10bit_videos_cpu (self , asset ):
1218
1248
# This just validates that we can decode 10-bit videos on CPU.
0 commit comments