Skip to content

Commit 03c636e

Browse files
committed
debug streamer for dml device
1 parent 481b1e9 commit 03c636e

File tree

1 file changed

+46
-14
lines changed

1 file changed

+46
-14
lines changed

depth.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -715,16 +715,24 @@ def overlay_fps(rgb: torch.Tensor, fps: float, color=(0.0, 255.0, 0.0)) -> torch
715715
return rgb * (1.0 - alpha) + overlay_color * alpha
716716

717717
# generate left and right eye view for streamer
718-
def make_sbs_core(rgb: torch.Tensor, depth: torch.Tensor, ipd_uv=0.064,
719-
depth_ratio=1.0, display_mode="Half-SBS") -> torch.Tensor:
718+
import torch
719+
import torch.nn.functional as F
720+
721+
def make_sbs_core(rgb: torch.Tensor,
722+
depth: torch.Tensor,
723+
ipd_uv=0.064,
724+
depth_ratio=1.0,
725+
display_mode="Half-SBS",
726+
device=DEVICE) -> torch.Tensor:
720727
"""
721728
Core tensor operations for side-by-side stereo.
729+
Keeps CUDA fast path (grid_sample) and fallback path (gather).
722730
Compatible with torch.compile.
723731
Inputs:
724-
rgb: [C,H,W] float tensor
725-
depth: [H,W] float tensor
732+
rgb: [C,H,W] float tensor
733+
depth: [H,W] float tensor
726734
Returns:
727-
SBS image [C,H,W] float tensor (0-255)
735+
SBS image [C,H,W] float tensor (0-255 range)
728736
"""
729737
C, H, W = rgb.shape
730738
img = rgb.unsqueeze(0) # [1,C,H,W]
@@ -734,14 +742,33 @@ def make_sbs_core(rgb: torch.Tensor, depth: torch.Tensor, ipd_uv=0.064,
734742
depth_strength = 0.05
735743
shifts = inv * max_px * depth_strength
736744

737-
# grid_sample path (works for CUDA/CPU/MPS)
738-
xs = torch.linspace(-1.0, 1.0, W, device=rgb.device, dtype=rgb.dtype).view(1,1,W).expand(1,H,W)
739-
ys = torch.linspace(-1.0, 1.0, H, device=rgb.device, dtype=rgb.dtype).view(1,H,1).expand(1,H,W)
740-
shift_norm = shifts * (2.0 / (W-1))
741-
grid_left = torch.stack([xs + shift_norm, ys], dim=-1)
742-
grid_right = torch.stack([xs - shift_norm, ys], dim=-1)
743-
sampled_left = F.grid_sample(img, grid_left, mode="bilinear", padding_mode="border", align_corners=True)[0]
744-
sampled_right = F.grid_sample(img, grid_right, mode="bilinear", padding_mode="border", align_corners=True)[0]
745+
# CUDA fast path: grid_sample
746+
if "CUDA" in DEVICE_INFO:
747+
xs = torch.linspace(-1.0, 1.0, W, device=device, dtype=MODEL_DTYPE).view(1, 1, W).expand(1, H, W)
748+
ys = torch.linspace(-1.0, 1.0, H, device=device, dtype=MODEL_DTYPE).view(1, H, 1).expand(1, H, W)
749+
shift_norm = shifts * (2.0 / (W - 1))
750+
751+
grid_left = torch.stack([xs + shift_norm, ys], dim=-1)
752+
grid_right = torch.stack([xs - shift_norm, ys], dim=-1)
753+
754+
sampled_left = F.grid_sample(img, grid_left, mode="bilinear",
755+
padding_mode="border", align_corners=True)[0]
756+
sampled_right = F.grid_sample(img, grid_right, mode="bilinear",
757+
padding_mode="border", align_corners=True)[0]
758+
759+
# Fallback path: vectorized gather (DirectML / MPS / CPU safe)
760+
else:
761+
base = torch.arange(W, device=device).view(1, -1).expand(H, -1).float()
762+
coords_left = (base + shifts).clamp(0, W - 1).long() # [H,W]
763+
coords_right = (base - shifts).clamp(0, W - 1).long() # [H,W]
764+
765+
# Left eye
766+
gather_idx_left = coords_left.unsqueeze(0).expand(C, H, W).unsqueeze(0) # [1,C,H,W]
767+
sampled_left = torch.gather(img.expand(1, C, H, W), 3, gather_idx_left)[0] # [C,H,W]
768+
769+
# Right eye
770+
gather_idx_right = coords_right.unsqueeze(0).expand(C, H, W).unsqueeze(0)
771+
sampled_right = torch.gather(img.expand(1, C, H, W), 3, gather_idx_right)[0]
745772

746773
# Aspect pad helper
747774
def pad_to_aspect_tensor(tensor, target_ratio=(16, 9)):
@@ -759,9 +786,14 @@ def pad_to_aspect_tensor(tensor, target_ratio=(16, 9)):
759786
pad_left = (new_w - w) // 2
760787
return F.pad(tensor, (pad_left, new_w - w - pad_left, 0, 0))
761788

789+
# Aspect pad & arrange SBS/TAB
762790
left = pad_to_aspect_tensor(sampled_left)
763791
right = pad_to_aspect_tensor(sampled_right)
764-
out = torch.cat([left, right], dim=2 if display_mode != "TAB" else 1)
792+
793+
if display_mode == "TAB":
794+
out = torch.cat([left, right], dim=1)
795+
else:
796+
out = torch.cat([left, right], dim=2)
765797

766798
if display_mode != "Full-SBS":
767799
out = F.interpolate(out.unsqueeze(0), size=left.shape[1:], mode="area")[0]

0 commit comments

Comments
 (0)