Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions backend/OBSController.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,62 @@ def __init__(self):
pass

def validate_ip(self, host: str):
if host in ("localhost", "127.0.0.1"):
return True
"""
Validate host address (IPv4, IPv6, or hostname).

# We're explicitly disallowing non-localhost DNS entries here.
# Continuing this pattern for now, but this is probably the wrong thing
# to do long-term.
Returns True if the host is valid, False otherwise.
"""
if not host or not host.strip():
return False

try:
addr = ipaddress.ip_address(host)
host = host.strip()

# Handle bracket-wrapped IPv6 addresses [::1]
if host.startswith('[') and host.endswith(']'):
host = host[1:-1]

# And we're disallowing IPv6 entries here, for compatibility with
# previous implementations. Again, probably the wrong thing
# long-term, but implementing this way to mitigate risk while we're
# in a bad-push state.
if not addr.version == ipaddress.IPv4Address.version:
raise ValueError()
# Try to parse as IP address (IPv4 or IPv6)
try:
ipaddress.ip_address(host)
return True
except ValueError:
pass

# Try to validate as hostname/DNS entry
# Valid hostname: alphanumeric, hyphens, dots, max 253 chars
# Each label (between dots) max 63 chars, can't start/end with hyphen
if len(host) > 253:
return False

# Allow localhost explicitly
if host.lower() == "localhost":
return True

# Split into labels and validate each
labels = host.split('.')
if not labels:
return False

# Pattern for valid hostname labels
for label in labels:
if not label or len(label) > 63:
return False
# Must start and end with alphanumeric
if not (label[0].isalnum() and label[-1].isalnum()):
return False
# Labels can't be all numeric (would be confused with IP)
if label.isdigit():
# If all labels are numeric, it looks like an IP address
# and should have been caught by ipaddress.ip_address()
if all(l.isdigit() for l in labels):
return False
# Middle characters can be alphanumeric or hyphen
for char in label:
if not (char.isalnum() or char == '-'):
return False

return True

def on_connect(self, obs):
self.connected = True

Expand All @@ -58,18 +94,32 @@ def connect_to(self, host=None, port=None, timeout=1, legacy=False, **kwargs):
self.event_obs.disconnect()
return False

# For IPv6 addresses, wrap in brackets if not already wrapped
# This is required for WebSocket URL construction (ws://[::1]:port)
connection_host = host
try:
addr = ipaddress.ip_address(host)
if isinstance(addr, ipaddress.IPv6Address):
# Only wrap if not already wrapped
if not (host.startswith('[') and host.endswith(']')):
connection_host = f"[{host}]"
log.debug(f"Wrapped IPv6 address: {host} -> {connection_host}")
except ValueError:
# Not an IP address, use as-is (hostname)
pass

try:
log.debug(f"Trying to connect to obs with legacy: {legacy}")
super().__init__(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
super().__init__(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.connect()
log.info("Successfully connected to OBS")
return True
except (obswebsocket.exceptions.ConnectionFailure, ValueError) as e:
try:
log.error(f"Failed to connect to OBS with legacy: {legacy}, trying with legacy: {not legacy}")
super().__init__(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
super().__init__(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.event_obs = obsws(host=connection_host, port=port, timeout=timeout, legacy=not legacy, on_connect=self.on_connect, on_disconnect=self.on_disconnect, authreconnect=5, **kwargs)
self.connect()
log.info("Successfully connected to OBS")

Expand Down