@@ -145,6 +145,8 @@ def get_devices():
145145 "IPD" : 0.064 ,
146146 "Display Mode" : "Half-SBS" ,
147147 "FP16" : True ,
148+ "Inference Optimizer" : "None" ,
149+ "Recompile TensorRT" : False ,
148150 "Download Path" : "models" ,
149151 "HF Endpoint" : "https://hf-mirror.com" ,
150152 "Device" : 0 ,
@@ -171,6 +173,8 @@ def get_devices():
171173 "Anti-aliasing:" : "Anti-aliasing:" ,
172174 "Foreground Scale:" : "Foreground Scale:" ,
173175 "FP16" : "FP16" ,
176+ "Inference Optimizer:" : "Inference Optimizer:" ,
177+ "Recompile TensorRT:" : "Recompile TensorRT" ,
174178 "Download Path:" : "Download Path:" ,
175179 "Browse..." : "Browse..." ,
176180 "Stop" : "Stop" ,
@@ -231,6 +235,8 @@ def get_devices():
231235 "Anti-aliasing:" : "抗锯齿:" ,
232236 "Foreground Scale:" : "前景缩放:" ,
233237 "FP16" : "半精度浮点 (F16)" ,
238+ "Inference Optimizer:" : "推理优化器:" ,
239+ "Recompile TensorRT:" : "重新编译TensorRT" ,
234240 "Download Path:" : "下载路径:" ,
235241 "Browse..." : "浏览..." ,
236242 "Stop" : "停止" ,
@@ -284,12 +290,20 @@ def __init__(self):
284290 self .pad = {"padx" : 8 , "pady" : 6 }
285291 self .title (f"Desktop2Stereo v{ VERSION } " )
286292 self .minsize (800 , 420 ) # Increased height for new controls
287- self .resizable (False , False )
293+ self .resizable (True , False )
288294 self .language = "EN"
289295 self .loaded_model_list = DEFAULT_MODEL_LIST .copy ()
290296 self .selected_window_name = ""
291297 self ._window_objects = [] # Store window objects for reference
292298 self .cfg = {} # Store the loaded configuration
299+
300+ # Add optimizer options based on device type
301+ self .optimizer_options = {
302+ "CUDA" : ["TensorRT" , "Torch.compile" , "None" ],
303+ "DirectML" : ["Unlock Thread (Streamer)" , "None" ],
304+ # "MPS": ["Torch.compile", "None"],
305+ "Other" : ["None" ]
306+ }
293307
294308 try :
295309 icon_img = Image .open ("icon.ico" )
@@ -329,6 +343,7 @@ def __init__(self):
329343 self .language_var .set (self .language )
330344 self .protocol ("WM_DELETE_WINDOW" , self .on_close ) # Bind to Close of GUI to turn off all threads
331345 self .process = None # Keep track of the spawned process
346+ self .inference_optimizer_var .trace_add ("write" , self .on_optimizer_change )
332347
333348 def on_close (self ):
334349 """Handle GUI window closing: stop process & cleanup."""
@@ -501,12 +516,26 @@ def create_widgets(self):
501516 self .depth_model_cb .grid (row = 10 , column = 1 , columnspan = 2 , sticky = "ew" , ** self .pad )
502517 self .depth_model_cb .bind ("<<ComboboxSelected>>" , self .on_depth_model_change )
503518
519+
520+ # Add Inference Optimizer dropdown after Device selection
521+ self .label_inference_optimizer = ttk .Label (self .content_frame , text = "Inference Optimizer:" )
522+ self .label_inference_optimizer .grid (row = 11 , column = 0 , sticky = "w" , ** self .pad )
523+
524+ self .inference_optimizer_var = tk .StringVar ()
525+ self .inference_optimizer_cb = ttk .Combobox (self .content_frame , textvariable = self .inference_optimizer_var , state = "readonly" )
526+ self .inference_optimizer_cb .grid (row = 11 , column = 1 , columnspan = 2 , sticky = "ew" , ** self .pad )
527+
528+ self .recompile_trt_var = tk .BooleanVar ()
529+ self .recompile_trt_cb = ttk .Checkbutton (self .content_frame , text = "Recompile TensorRT" , variable = self .recompile_trt_var )
530+ self .recompile_trt_cb .grid (row = 11 , column = 3 , sticky = "w" , ** self .pad )
531+
504532 # HF Endpoint
505533 self .label_hf_endpoint = ttk .Label (self .content_frame , text = "HF Endpoint:" )
506- self .label_hf_endpoint .grid (row = 11 , column = 0 , sticky = "w" , ** self .pad )
534+ self .label_hf_endpoint .grid (row = 12 , column = 0 , sticky = "w" , ** self .pad )
507535 self .hf_endpoint_var = tk .StringVar ()
508- self .hf_endpoint_entry = ttk .Entry (self .content_frame , textvariable = self .hf_endpoint_var )
509- self .hf_endpoint_entry .grid (row = 11 , column = 1 , sticky = "ew" , ** self .pad )
536+ self .hf_endpoint_cb = ttk .Combobox (self .content_frame , textvariable = self .hf_endpoint_var , state = "normal" )
537+ self .hf_endpoint_cb ["values" ] = ["https://huggingface.co" , "https://hf-mirror.com" ]
538+ self .hf_endpoint_cb .grid (row = 12 , column = 1 , sticky = "ew" , ** self .pad )
510539
511540 # Streamer Host and Port (only visible when run mode is streamer)
512541 self .label_streamer_host = ttk .Label (self .content_frame , text = "Streamer URL:" )
@@ -534,10 +563,10 @@ def create_widgets(self):
534563 self .btn_reset .grid (row = 10 , column = 3 , sticky = "ew" , ** self .pad )
535564
536565 self .btn_stop = ttk .Button (self .content_frame , text = "Stop" , command = self .stop_process )
537- self .btn_stop .grid (row = 11 , column = 2 , sticky = "ew" , ** self .pad )
566+ self .btn_stop .grid (row = 12 , column = 2 , sticky = "ew" , ** self .pad )
538567
539568 self .btn_run = ttk .Button (self .content_frame , text = "Run" , command = self .save_settings )
540- self .btn_run .grid (row = 11 , column = 3 , sticky = "ew" , ** self .pad )
569+ self .btn_run .grid (row = 12 , column = 3 , sticky = "ew" , ** self .pad )
541570
542571 # Column weights inside content frame
543572 for col in range (4 ):
@@ -546,7 +575,54 @@ def create_widgets(self):
546575 # Status bar at bottom
547576 self .status_label = tk .Label (self , text = "" , anchor = "w" , relief = "sunken" , padx = 20 , pady = 4 )
548577 self .status_label .grid (row = 1 , column = 0 , sticky = "we" ) # no padding
549-
578+ # Bind device change event
579+ self .device_var .trace_add ("write" , self .on_device_change )
580+
581+ def on_optimizer_change (self , * args ):
582+ """Show/hide TensorRT recompile option based on optimizer selection"""
583+ if self .inference_optimizer_var .get () == "TensorRT" :
584+ self .recompile_trt_cb .grid ()
585+ else :
586+ self .recompile_trt_cb .grid_remove ()
587+
588+ def on_device_change (self , * args ):
589+ """Update Inference Optimizer options based on selected device"""
590+ device_label = self .device_var .get ()
591+
592+ # Determine device type from label
593+ if "CUDA" in device_label :
594+ device_type = "CUDA"
595+ elif "DirectML" in device_label :
596+ device_type = "DirectML"
597+ # elif "MPS" in device_label:
598+ # device_type = "MPS"
599+ else :
600+ device_type = "Other"
601+
602+ # Update optimizer options
603+ self .inference_optimizer_cb ["values" ] = self .optimizer_options [device_type ]
604+
605+ # Set default value if not already set
606+ if not self .inference_optimizer_var .get ():
607+ self .inference_optimizer_var .set ("None" )
608+
609+ # Show/hide based on device type (always show, but could be hidden if needed)
610+ # Show/hide based on device type
611+ if device_type == "Other" :
612+ self .label_inference_optimizer .grid_remove ()
613+ self .inference_optimizer_cb .grid_remove ()
614+ self .recompile_trt_cb .grid_remove ()
615+ else :
616+ self .label_inference_optimizer .grid ()
617+ self .inference_optimizer_cb .grid ()
618+
619+ # Show/hide TensorRT recompile option based on current optimizer selection
620+ if self .inference_optimizer_var .get () == "TensorRT" :
621+ self .recompile_trt_cb .grid ()
622+ else :
623+ self .recompile_trt_cb .grid_remove ()
624+
625+
550626 def refresh_window_list (self ):
551627 """Refresh the list of available windows"""
552628 try :
@@ -719,7 +795,9 @@ def update_language_texts(self):
719795 self .label_run_mode .config (text = texts .get ("Run Mode:" , "Run Mode:" ))
720796 localized_run_vals = [texts .get ("Viewer" , "Viewer" ), texts .get ("Streamer" , "Streamer" )]
721797 self .run_mode_cb ["values" ] = localized_run_vals
722-
798+ # Add Inference Optimizer text update
799+ self .label_inference_optimizer .config (text = texts .get ("Inference Optimizer:" , "Inference Optimizer:" ))
800+ self .recompile_trt_cb .config (text = texts .get ("Recompile TensorRT:" , "Recompile TensorRT" ))
723801
724802 # Select the appropriate label
725803 if self .run_mode_key == "Viewer" :
@@ -945,7 +1023,11 @@ def apply_config(self, cfg, keep_optional=True):
9451023 self .foreground_scale_cb .set (str (cfg .get ("Foreground Scale" , DEFAULTS ["Foreground Scale" ])))
9461024 self .fp16_var .set (cfg .get ("FP16" , DEFAULTS ["FP16" ]))
9471025 self .download_var .set (cfg .get ("Download Path" , DEFAULTS ["Download Path" ]))
948- self .hf_endpoint_var .set (cfg .get ("HF Endpoint" , DEFAULTS ["HF Endpoint" ]))
1026+ hf_endpoint = cfg .get ("HF Endpoint" , DEFAULTS ["HF Endpoint" ])
1027+ self .hf_endpoint_var .set (hf_endpoint )
1028+ # If the endpoint is not in the predefined list, add it
1029+ if hf_endpoint not in self .hf_endpoint_cb ["values" ]:
1030+ self .hf_endpoint_cb ["values" ] = list (self .hf_endpoint_cb ["values" ]) + [hf_endpoint ]
9491031 if keep_optional : # no update for language
9501032 self .language_var .set (cfg .get ("Language" , DEFAULTS ["Language" ]))
9511033
@@ -964,6 +1046,32 @@ def apply_config(self, cfg, keep_optional=True):
9641046 self .update_language_texts ()
9651047 self .on_run_mode_change ()
9661048 self .on_capture_mode_change ()
1049+
1050+ # Apply Inference Optimizer setting with validation
1051+ saved_optimizer = cfg .get ("Inference Optimizer" , DEFAULTS ["Inference Optimizer" ])
1052+
1053+ # Validate that the saved optimizer is compatible with the current device
1054+ device_label = self .device_var .get ()
1055+ if "CUDA" in device_label :
1056+ device_type = "CUDA"
1057+ elif "DirectML" in device_label :
1058+ device_type = "DirectML"
1059+ else :
1060+ device_type = "Other"
1061+
1062+ # Check if saved optimizer is valid for current device
1063+ valid_optimizers = self .optimizer_options [device_type ]
1064+ if saved_optimizer not in valid_optimizers :
1065+ # Reset to default if incompatible
1066+ saved_optimizer = "None"
1067+
1068+ self .inference_optimizer_var .set (saved_optimizer )
1069+
1070+ # Trigger device change to update optimizer options
1071+ self .recompile_trt_var .set (cfg .get ("Recompile TensorRT" , DEFAULTS ["Recompile TensorRT" ]))
1072+
1073+ # Trigger device change to update optimizer options
1074+ self .on_device_change ()
9671075
9681076 def update_depth_resolution_options (self , model_name ):
9691077 """Update depth resolution options based on selected model"""
@@ -1071,6 +1179,8 @@ def save_settings(self):
10711179 "Run Mode" : self .run_mode_key ,
10721180 "Streamer Port" : int (self .streamer_port_var .get ()),
10731181 "Stream Quality" : int (self .stream_quality_cb .get ()),
1182+ "Inference Optimizer" : self .inference_optimizer_var .get (),
1183+ "Recompile TensorRT" : self .recompile_trt_var .get (),
10741184 }
10751185 success = self .save_yaml ("settings.yaml" , cfg )
10761186 if success :
0 commit comments