@@ -271,7 +271,7 @@ class PoolWithProgress:
271271 def __init__ (
272272 self ,
273273 progress_bar : ProgressBar ,
274- process_count : int = os . cpu_count () ,
274+ process_count : int ,
275275 refresh_rate_seconds : float = DEF_MAX_REFRESH_RATE ,
276276 sub_ticks : int = DEF_SUB_TICKS ,
277277 max_worker_reuse : Optional [int ] = None ,
@@ -764,8 +764,10 @@ def get_maximums(
764764 return poses
765765
766766 def _run_full_passes (self , progress_bar : Optional [ProgressBar ]):
767+ thread_count = self ._get_thread_count ()
768+
767769 for (i , frame_pass_builder ) in enumerate (self .FULL_PASSES ):
768- frame_pass = frame_pass_builder (self ._width , self ._height , True )
770+ frame_pass = frame_pass_builder (self ._width , self ._height , thread_count )
769771
770772 if (progress_bar is not None ):
771773 progress_bar .message (
@@ -861,7 +863,7 @@ def _run_segment_pre_initialized(
861863 frame_pass_builders : List [FramePassBuilder ],
862864 width : int ,
863865 height : int ,
864- allow_multi_threading : bool ,
866+ thread_count : int ,
865867 fix_frame_idx : int ,
866868 fix_frame_score : float ,
867869 progress_bar : Optional [ProgressBar ] = None ,
@@ -871,7 +873,7 @@ def _run_segment_pre_initialized(
871873 frame_pass_builders ,
872874 width ,
873875 height ,
874- allow_multi_threading ,
876+ thread_count ,
875877 fix_frame_idx ,
876878 fix_frame_score ,
877879 progress_bar ,
@@ -885,7 +887,7 @@ def _run_segment(
885887 frame_pass_builders : List [FramePassBuilder ],
886888 width : int ,
887889 height : int ,
888- allow_multi_threading : bool ,
890+ thread_count : int ,
889891 fix_frame_idx : int ,
890892 fix_frame_score ,
891893 progress_bar : Optional [ProgressBar ] = None ,
@@ -925,7 +927,7 @@ def _run_segment(
925927 progress_bar .inc_rerun_counter ()
926928
927929 for (i , frame_pass_builder ) in enumerate (frame_pass_builders ):
928- frame_pass = frame_pass_builder (width , height , allow_multi_threading )
930+ frame_pass = frame_pass_builder (width , height , thread_count )
929931
930932 sub_frame = frame_pass .run_pass (
931933 sub_frame ,
@@ -950,7 +952,14 @@ def _get_segment(self, index: int):
950952 sub_frame .frames = self ._frame_holder .frames [start :end ]
951953 sub_frame .metadata = self ._frame_holder .metadata
952954
953- return (sub_frame , self .SEGMENTED_PASSES , self ._width , self ._height , False , fix_frame - start , segment_score )
955+ return (
956+ sub_frame ,
957+ self .SEGMENTED_PASSES ,
958+ self ._width ,
959+ self ._height ,
960+ 0 ,
961+ fix_frame - start ,
962+ segment_score )
954963
955964 def _set_segment (self , index : int , frame_data : ForwardBackwardData ):
956965 start , end , fix_frame = self ._segments [index ]
@@ -1047,31 +1056,20 @@ def _run_segmented_passes(
10471056
10481057 self ._frame_holder .allow_pickle = False
10491058
1050- if (thread_count <= 0 ):
1059+ if (thread_count <= 1 ):
10511060 pass_count = (len (self .SEGMENTED_PASSES ) + 1 ) * total_segments
10521061
1053- passes_can_use_pool = any (b .clazz .UTILIZE_GLOBAL_POOL for b in self .SEGMENTED_PASSES )
1054- allow_multithread = self .settings .allow_pass_multithreading
1055-
10561062 wrapper_bar = NestedProgressIndicator (
10571063 progress_bar ,
10581064 total = pass_count ,
10591065 ticks = int (self ._frame_holder .num_frames / pass_count )
10601066 )
10611067 progress_bar .message ("Running on Segments..." )
10621068
1063- if (passes_can_use_pool and allow_multithread ):
1064- with PoolWithProgress .get_optimal_ctx ().Pool (processes = os .cpu_count ()) as pool :
1065- FramePass .GLOBAL_POOL = AntiCloseObject (pool )
1066- for is_pre_init , segment_idx in self ._iter_run_levels (segment_idxs , run_level_data ):
1067- for idx in segment_idx :
1068- frm , segs , width , height , __ , fix_frame_idx , fix_frame_score = self ._get_segment (idx )
1069- self ._run_segment (frm , segs , width , height , self .settings .allow_pass_multithreading , fix_frame_idx , fix_frame_score , wrapper_bar , is_pre_init )
1070- else :
1071- for is_pre_init , segment_idx in self ._iter_run_levels (segment_idxs , run_level_data ):
1072- for idx in segment_idx :
1073- frm , segs , width , height , __ , fix_frame_idx , fix_frame_score = self ._get_segment (idx )
1074- self ._run_segment (frm , segs , width , height , self .settings .allow_pass_multithreading , fix_frame_idx , fix_frame_score , wrapper_bar , is_pre_init )
1069+ for is_pre_init , segment_idx in self ._iter_run_levels (segment_idxs , run_level_data ):
1070+ for idx in segment_idx :
1071+ frm , segs , width , height , __ , fix_frame_idx , fix_frame_score = self ._get_segment (idx )
1072+ self ._run_segment (frm , segs , width , height , thread_count , fix_frame_idx , fix_frame_score , wrapper_bar , is_pre_init )
10751073
10761074 FramePass .GLOBAL_POOL = None
10771075 else :
@@ -1534,11 +1532,6 @@ def get_settings(cls) -> ConfigSpec:
15341532 "Defaults to None, which resolves to os.cpu_count() at runtime. "
15351533 "If set to 0, disables multithreading..."
15361534 ),
1537- "allow_pass_multithreading" : (
1538- True ,
1539- bool ,
1540- "Whether or not to allow frame passes to utilize multithreading. Defaults to True."
1541- ),
15421535 "segment_size" : (
15431536 200 ,
15441537 type_casters .RangedInteger (10 , np .inf ),
0 commit comments