@@ -95,7 +95,6 @@ def decode_av(pipe_input, frame_callback, put_metadata, output_width, output_hei
         frame = cast(av.VideoFrame, frame)
         if frame.pts is None:
             continue
-
         # drop frames that come in too fast
         # TODO also check timing relative to wall clock
         pts_time = frame.time
@@ -108,15 +107,27 @@ def decode_av(pipe_input, frame_callback, put_metadata, output_width, output_hei
         else:
             # not delayed, so use prev pts to allow more jitter
             next_pts_time = next_pts_time + frame_interval
-
-        # h = 512
-        # w = int((512 * frame.width / frame.height) / 2) * 2  # force divisible by 2
-        # if frame.height > frame.width:
-        #     w = 512
-        #     h = int((512 * frame.height / frame.width) / 2) * 2
-
-        frame = reformatter.reformat(frame, format='rgba', width=output_width, height=output_height)
-        avframe = InputFrame.from_av_video(frame)
+        # Convert frame to image
+        image = frame.to_image()
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+        width, height = image.size
+
+        if output_width == output_height and width != height:
+            # Crop to center square if output is square but input isn't
+            square_size = min(width, height)
+            start_x = width // 2 - square_size // 2
+            start_y = height // 2 - square_size // 2
+            image = image.crop((start_x, start_y, start_x + square_size, start_y + square_size))
+        elif (output_width, output_height) != (width, height):
+            # Resize if dimensions don't match output
+            image = image.resize((output_width, output_height))
+
+        # Convert to tensor
+        image_np = np.array(image).astype(np.float32) / 255.0
+        tensor = torch.tensor(image_np).unsqueeze(0)
+
+        avframe = InputFrame.from_av_video(tensor, frame.pts, frame.time_base)
         avframe.log_timestamps["frame_init"] = time.time()
         frame_callback(avframe)
         continue
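
For reference, below is a minimal standalone sketch of the preprocessing path introduced in this hunk, assuming Pillow, NumPy, and PyTorch are available; `preprocess_image` is a hypothetical helper name used only for illustration, not part of the pipeline's API.

```python
# Minimal sketch of the frame preprocessing added above (assumption: Pillow,
# NumPy, and PyTorch are installed). `preprocess_image` is a hypothetical helper.
import numpy as np
import torch
from PIL import Image

def preprocess_image(image: Image.Image, output_width: int, output_height: int) -> torch.Tensor:
    if image.mode != "RGB":
        image = image.convert("RGB")
    width, height = image.size

    if output_width == output_height and width != height:
        # Square target, non-square input: center-crop to the largest square
        square_size = min(width, height)
        start_x = width // 2 - square_size // 2
        start_y = height // 2 - square_size // 2
        image = image.crop((start_x, start_y, start_x + square_size, start_y + square_size))
    elif (output_width, output_height) != (width, height):
        # Otherwise resize directly to the requested output dimensions
        image = image.resize((output_width, output_height))

    # HWC uint8 -> float32 in [0, 1], with a leading batch dimension (1, H, W, C)
    image_np = np.array(image).astype(np.float32) / 255.0
    return torch.tensor(image_np).unsqueeze(0)

# Example: a 1280x720 frame with a 512x512 target is center-cropped to 720x720,
# giving a tensor of shape (1, 720, 720, 3).
tensor = preprocess_image(Image.new("RGB", (1280, 720)), 512, 512)
print(tensor.shape)
```

Note that in the square-output branch this path crops to `min(width, height)` but does not resize to `output_width` x `output_height`, so the cropped resolution is passed through as-is; only the non-square branch resizes to the requested dimensions.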