@@ -715,16 +715,24 @@ def overlay_fps(rgb: torch.Tensor, fps: float, color=(0.0, 255.0, 0.0)) -> torch
715715 return rgb * (1.0 - alpha ) + overlay_color * alpha
716716
717717# Generate the left- and right-eye views for the streamer
718- def make_sbs_core (rgb : torch .Tensor , depth : torch .Tensor , ipd_uv = 0.064 ,
719- depth_ratio = 1.0 , display_mode = "Half-SBS" ) -> torch .Tensor :
718+ import torch
719+ import torch .nn .functional as F
720+
721+ def make_sbs_core (rgb : torch .Tensor ,
722+ depth : torch .Tensor ,
723+ ipd_uv = 0.064 ,
724+ depth_ratio = 1.0 ,
725+ display_mode = "Half-SBS" ,
726+ device = DEVICE ) -> torch .Tensor :
720727 """
721728 Core tensor operations for side-by-side stereo.
729+ Keeps CUDA fast path (grid_sample) and fallback path (gather).
722730 Compatible with torch.compile.
723731 Inputs:
724- rgb: [C,H,W] float tensor
725- depth: [H,W] float tensor
732+ rgb: [C,H,W] float tensor
733+ depth: [H,W] float tensor
726734 Returns:
727- SBS image [C,H,W] float tensor (0-255)
735+ SBS image [C,H,W] float tensor (0-255 range )
728736 """
729737 C , H , W = rgb .shape
730738 img = rgb .unsqueeze (0 ) # [1,C,H,W]
@@ -734,14 +742,33 @@ def make_sbs_core(rgb: torch.Tensor, depth: torch.Tensor, ipd_uv=0.064,
734742 depth_strength = 0.05
735743 shifts = inv * max_px * depth_strength
736744
737- # grid_sample path (works for CUDA/CPU/MPS)
738- xs = torch .linspace (- 1.0 , 1.0 , W , device = rgb .device , dtype = rgb .dtype ).view (1 ,1 ,W ).expand (1 ,H ,W )
739- ys = torch .linspace (- 1.0 , 1.0 , H , device = rgb .device , dtype = rgb .dtype ).view (1 ,H ,1 ).expand (1 ,H ,W )
740- shift_norm = shifts * (2.0 / (W - 1 ))
741- grid_left = torch .stack ([xs + shift_norm , ys ], dim = - 1 )
742- grid_right = torch .stack ([xs - shift_norm , ys ], dim = - 1 )
743- sampled_left = F .grid_sample (img , grid_left , mode = "bilinear" , padding_mode = "border" , align_corners = True )[0 ]
744- sampled_right = F .grid_sample (img , grid_right , mode = "bilinear" , padding_mode = "border" , align_corners = True )[0 ]
745+ # CUDA fast path: grid_sample
746+ if "CUDA" in DEVICE_INFO :
747+ xs = torch .linspace (- 1.0 , 1.0 , W , device = device , dtype = MODEL_DTYPE ).view (1 , 1 , W ).expand (1 , H , W )
748+ ys = torch .linspace (- 1.0 , 1.0 , H , device = device , dtype = MODEL_DTYPE ).view (1 , H , 1 ).expand (1 , H , W )
749+ shift_norm = shifts * (2.0 / (W - 1 ))
750+
751+ grid_left = torch .stack ([xs + shift_norm , ys ], dim = - 1 )
752+ grid_right = torch .stack ([xs - shift_norm , ys ], dim = - 1 )
753+
754+ sampled_left = F .grid_sample (img , grid_left , mode = "bilinear" ,
755+ padding_mode = "border" , align_corners = True )[0 ]
756+ sampled_right = F .grid_sample (img , grid_right , mode = "bilinear" ,
757+ padding_mode = "border" , align_corners = True )[0 ]
758+
759+ # Fallback path: vectorized gather (DirectML / MPS / CPU safe)
760+ else :
761+ base = torch .arange (W , device = device ).view (1 , - 1 ).expand (H , - 1 ).float ()
762+ coords_left = (base + shifts ).clamp (0 , W - 1 ).long () # [H,W]
763+ coords_right = (base - shifts ).clamp (0 , W - 1 ).long () # [H,W]
764+
765+ # Left eye
766+ gather_idx_left = coords_left .unsqueeze (0 ).expand (C , H , W ).unsqueeze (0 ) # [1,C,H,W]
767+ sampled_left = torch .gather (img .expand (1 , C , H , W ), 3 , gather_idx_left )[0 ] # [C,H,W]
768+
769+ # Right eye
770+ gather_idx_right = coords_right .unsqueeze (0 ).expand (C , H , W ).unsqueeze (0 )
771+ sampled_right = torch .gather (img .expand (1 , C , H , W ), 3 , gather_idx_right )[0 ]
745772
746773 # Aspect pad helper
747774 def pad_to_aspect_tensor (tensor , target_ratio = (16 , 9 )):
@@ -759,9 +786,14 @@ def pad_to_aspect_tensor(tensor, target_ratio=(16, 9)):
759786 pad_left = (new_w - w ) // 2
760787 return F .pad (tensor , (pad_left , new_w - w - pad_left , 0 , 0 ))
761788
789+ # Aspect pad & arrange SBS/TAB
762790 left = pad_to_aspect_tensor (sampled_left )
763791 right = pad_to_aspect_tensor (sampled_right )
764- out = torch .cat ([left , right ], dim = 2 if display_mode != "TAB" else 1 )
792+
793+ if display_mode == "TAB" :
794+ out = torch .cat ([left , right ], dim = 1 )
795+ else :
796+ out = torch .cat ([left , right ], dim = 2 )
765797
766798 if display_mode != "Full-SBS" :
767799 out = F .interpolate (out .unsqueeze (0 ), size = left .shape [1 :], mode = "area" )[0 ]
0 commit comments