Skip to content

Commit 90654b4

Browse files
committed
Simplify selection of features
1 parent 6b4e36c commit 90654b4

File tree

9 files changed

+151
-86
lines changed

9 files changed

+151
-86
lines changed

sshuttle/client.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sshuttle.ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
1515
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, islocal, \
1616
resolvconf_nameservers
17-
from sshuttle.methods import get_method
17+
from sshuttle.methods import get_method, Features
1818

1919
_extra_fd = os.open('/dev/null', os.O_RDONLY)
2020

@@ -505,19 +505,44 @@ def main(listenip_v6, listenip_v4,
505505

506506
fw = FirewallClient(method_name)
507507

508-
features = fw.method.get_supported_features()
508+
# Get family specific subnet lists
509+
if dns:
510+
nslist += resolvconf_nameservers()
511+
512+
subnets = subnets_include + subnets_exclude # we don't care here
513+
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
514+
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
515+
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
516+
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
517+
518+
# Check features available
519+
avail = fw.method.get_supported_features()
520+
required = Features()
521+
509522
if listenip_v6 == "auto":
510-
if features.ipv6:
523+
if avail.ipv6:
511524
listenip_v6 = ('::1', 0)
512525
else:
513526
listenip_v6 = None
514527

528+
required.ipv6 = len(subnets_v6) > 0 or len(nslist_v6) > 0
529+
required.udp = avail.udp
530+
required.dns = len(nslist) > 0
531+
532+
fw.method.assert_features(required)
533+
534+
if required.ipv6 and listenip_v6 is None:
535+
raise Fatal("IPv6 required but not listening.")
536+
537+
# display features enabled
538+
debug1("IPv6 enabled: %r\n" % required.ipv6)
539+
debug1("UDP enabled: %r\n" % required.udp)
540+
debug1("DNS enabled: %r\n" % required.dns)
541+
542+
# bind to required ports
515543
if listenip_v4 == "auto":
516544
listenip_v4 = ('127.0.0.1', 0)
517545

518-
udp = features.udp
519-
debug1("UDP enabled: %r\n" % udp)
520-
521546
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
522547
# if both ports given, no need to search for a spare port
523548
ports = [0, ]
@@ -536,7 +561,7 @@ def main(listenip_v6, listenip_v4,
536561
tcp_listener = MultiListener()
537562
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
538563

539-
if udp:
564+
if required.udp:
540565
udp_listener = MultiListener(socket.SOCK_DGRAM)
541566
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
542567
else:
@@ -584,10 +609,7 @@ def main(listenip_v6, listenip_v4,
584609
udp_listener.print_listening("UDP redirector")
585610

586611
bound = False
587-
if dns or nslist:
588-
if dns:
589-
nslist += resolvconf_nameservers()
590-
dns = True
612+
if required.dns:
591613
# search for spare port for DNS
592614
debug2('Binding DNS:')
593615
ports = range(12300, 9000, -1)
@@ -628,17 +650,41 @@ def main(listenip_v6, listenip_v4,
628650
dnsport_v4 = 0
629651
dns_listener = None
630652

631-
fw.method.check_settings(udp, dns)
653+
# Last minute sanity checks.
654+
# These should never fail.
655+
# If these do fail, something is broken above.
656+
if len(subnets_v6) > 0:
657+
assert required.ipv6
658+
if redirectport_v6 == 0:
659+
raise Fatal("IPv6 subnets defined but not listening")
660+
661+
if len(nslist_v6) > 0:
662+
assert required.dns
663+
assert required.ipv6
664+
if dnsport_v6 == 0:
665+
raise Fatal("IPv6 ns servers defined but not listening")
666+
667+
if len(subnets_v4) > 0:
668+
if redirectport_v4 == 0:
669+
raise Fatal("IPv4 subnets defined but not listening")
670+
671+
if len(nslist_v4) > 0:
672+
if dnsport_v4 == 0:
673+
raise Fatal("IPv4 ns servers defined but not listening")
674+
675+
# setup method specific stuff on listeners
632676
fw.method.setup_tcp_listener(tcp_listener)
633677
if udp_listener:
634678
fw.method.setup_udp_listener(udp_listener)
635679
if dns_listener:
636680
fw.method.setup_udp_listener(dns_listener)
637681

682+
# start the firewall
638683
fw.setup(subnets_include, subnets_exclude, nslist,
639684
redirectport_v6, redirectport_v4, dnsport_v6, dnsport_v4,
640-
udp)
685+
required.udp)
641686

687+
# start the client process
642688
try:
643689
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
644690
python, latency_control, dns_listener,

sshuttle/firewall.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,26 +178,23 @@ def main(method_name, syslog):
178178
try:
179179
debug1('firewall manager: setting up.\n')
180180

181-
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
182181
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
183-
if port_v6 > 0:
182+
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
183+
184+
if len(subnets_v6) > 0 or len(nslist_v6) > 0:
184185
debug2('firewall manager: setting up IPv6.\n')
185186
method.setup_firewall(
186187
port_v6, dnsport_v6, nslist_v6,
187188
socket.AF_INET6, subnets_v6, udp)
188-
elif len(subnets_v6) > 0:
189-
debug1("IPv6 subnets defined but IPv6 disabled\n")
190189

191-
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
192190
subnets_v4 = [i for i in subnets if i[0] == socket.AF_INET]
193-
if port_v4 > 0:
191+
nslist_v4 = [i for i in nslist if i[0] == socket.AF_INET]
192+
193+
if len(subnets_v4) > 0 or len(nslist_v4) > 0:
194194
debug2('firewall manager: setting up IPv4.\n')
195195
method.setup_firewall(
196196
port_v4, dnsport_v4, nslist_v4,
197197
socket.AF_INET, subnets_v4, udp)
198-
elif len(subnets_v4) > 0:
199-
debug1('firewall manager: '
200-
'IPv4 subnets defined but IPv4 disabled\n')
201198

202199
stdout.write('STARTED\n')
203200

sshuttle/methods/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,13 @@ def setup_tcp_listener(self, tcp_listener):
6262
def setup_udp_listener(self, udp_listener):
6363
pass
6464

65-
def check_settings(self, udp, dns):
66-
if udp:
67-
Fatal("UDP support not supported with method %s.\n" % self.name)
65+
def assert_features(self, features):
66+
avail = self.get_supported_features()
67+
for key in ["udp", "dns", "ipv6"]:
68+
if getattr(features, key) and not getattr(avail, key):
69+
raise Fatal(
70+
"Feature %s not supported with method %s.\n" %
71+
(key, self.name))
6872

6973
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
7074
raise NotImplementedError()

sshuttle/methods/nat.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,12 @@ def _ipt_ttl(*args):
5555
'-p', 'tcp',
5656
'--to-ports', str(port))
5757

58-
if dnsport:
59-
for f, ip in [i for i in nslist if i[0] == family]:
60-
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
61-
'--dest', '%s/32' % ip,
62-
'-p', 'udp',
63-
'--dport', '53',
64-
'--to-ports', str(dnsport))
58+
for f, ip in [i for i in nslist if i[0] == family]:
59+
_ipt_ttl('-A', chain, '-j', 'REDIRECT',
60+
'--dest', '%s/32' % ip,
61+
'-p', 'udp',
62+
'--dport', '53',
63+
'--to-ports', str(dnsport))
6564

6665
def restore_firewall(self, port, family, udp):
6766
# only ipv4 supported with NAT

sshuttle/methods/pf.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -181,27 +181,28 @@ def setup_firewall(self, port, dnsport, nslist, family, subnets, udp):
181181
if udp:
182182
raise Exception("UDP not supported by pf method_name")
183183

184-
includes = []
185-
# If a given subnet is both included and excluded, list the
186-
# exclusion first; the table will ignore the second, opposite
187-
# definition
188-
for f, swidth, sexclude, snet in sorted(
189-
subnets, key=lambda s: (s[1], s[2]), reverse=True):
190-
includes.append(b"%s%s/%d" %
191-
(b"!" if sexclude else b"",
192-
snet.encode("ASCII"),
193-
swidth))
194-
195-
tables.append(
196-
b'table <forward_subnets> {%s}' % b','.join(includes))
197-
translating_rules.append(
198-
b'rdr pass on lo0 proto tcp '
199-
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
200-
filtering_rules.append(
201-
b'pass out route-to lo0 inet proto tcp '
202-
b'to <forward_subnets> keep state')
203-
204-
if dnsport:
184+
if len(subnets) > 0:
185+
includes = []
186+
# If a given subnet is both included and excluded, list the
187+
# exclusion first; the table will ignore the second, opposite
188+
# definition
189+
for f, swidth, sexclude, snet in sorted(
190+
subnets, key=lambda s: (s[1], s[2]), reverse=True):
191+
includes.append(b"%s%s/%d" %
192+
(b"!" if sexclude else b"",
193+
snet.encode("ASCII"),
194+
swidth))
195+
196+
tables.append(
197+
b'table <forward_subnets> {%s}' % b','.join(includes))
198+
translating_rules.append(
199+
b'rdr pass on lo0 proto tcp '
200+
b'to <forward_subnets> -> 127.0.0.1 port %r' % port)
201+
filtering_rules.append(
202+
b'pass out route-to lo0 inet proto tcp '
203+
b'to <forward_subnets> keep state')
204+
205+
if len(nslist) > 0:
205206
tables.append(
206207
b'table <dns_servers> {%s}' %
207208
b','.join([ns[1].encode("ASCII") for ns in nslist]))

sshuttle/methods/tproxy.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def recv_udp(listener, bufsize):
5959
ip = socket.inet_ntop(family, cmsg_data[start:start + length])
6060
dstip = (ip, port)
6161
break
62+
print("xxxxx", srcip, dstip)
6263
return (srcip, dstip, data)
6364
elif recvmsg == "socket_ext":
6465
def recv_udp(listener, bufsize):
@@ -187,16 +188,15 @@ def _ipt_ttl(*args):
187188
_ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
188189
'-m', 'udp', '-p', 'udp')
189190

190-
if dnsport:
191-
for f, ip in [i for i in nslist if i[0] == family]:
192-
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
193-
'--dest', '%s/32' % ip,
194-
'-m', 'udp', '-p', 'udp', '--dport', '53')
195-
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
196-
'--tproxy-mark', '0x1/0x1',
197-
'--dest', '%s/32' % ip,
198-
'-m', 'udp', '-p', 'udp', '--dport', '53',
199-
'--on-port', str(dnsport))
191+
for f, ip in [i for i in nslist if i[0] == family]:
192+
_ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
193+
'--dest', '%s/32' % ip,
194+
'-m', 'udp', '-p', 'udp', '--dport', '53')
195+
_ipt('-A', tproxy_chain, '-j', 'TPROXY',
196+
'--tproxy-mark', '0x1/0x1',
197+
'--dest', '%s/32' % ip,
198+
'-m', 'udp', '-p', 'udp', '--dport', '53',
199+
'--on-port', str(dnsport))
200200

201201
for f, swidth, sexclude, snet \
202202
in sorted(subnets, key=lambda s: s[1], reverse=True):
@@ -267,16 +267,3 @@ def _ipt_ttl(*args):
267267
if ipt_chain_exists(family, table, divert_chain):
268268
_ipt('-F', divert_chain)
269269
_ipt('-X', divert_chain)
270-
271-
def check_settings(self, udp, dns):
272-
if udp and recvmsg is None:
273-
raise Fatal("tproxy UDP support requires recvmsg function.\n")
274-
275-
if dns and recvmsg is None:
276-
raise Fatal("tproxy DNS support requires recvmsg function.\n")
277-
278-
if udp:
279-
debug1("tproxy UDP support enabled.\n")
280-
281-
if dns:
282-
debug1("tproxy DNS support enabled.\n")

sshuttle/tests/test_methods_nat.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import socket
44
import struct
55

6+
from sshuttle.helpers import Fatal
67
from sshuttle.methods import get_method
78

89

@@ -11,6 +12,7 @@ def test_get_supported_features():
1112
features = method.get_supported_features()
1213
assert not features.ipv6
1314
assert not features.udp
15+
assert features.dns
1416

1517

1618
def test_get_tcp_dstip():
@@ -52,10 +54,18 @@ def test_setup_udp_listener():
5254
assert listener.mock_calls == []
5355

5456

55-
def test_check_settings():
57+
def test_assert_features():
5658
method = get_method('nat')
57-
method.check_settings(True, True)
58-
method.check_settings(False, True)
59+
features = method.get_supported_features()
60+
method.assert_features(features)
61+
62+
features.udp = True
63+
with pytest.raises(Fatal):
64+
method.assert_features(features)
65+
66+
features.ipv6 = True
67+
with pytest.raises(Fatal):
68+
method.assert_features(features)
5969

6070

6171
def test_firewall_command():

sshuttle/tests/test_methods_pf.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import socket
44

55
from sshuttle.methods import get_method
6+
from sshuttle.helpers import Fatal
67

78

89
def test_get_supported_features():
910
method = get_method('pf')
1011
features = method.get_supported_features()
1112
assert not features.ipv6
1213
assert not features.udp
14+
assert features.dns
1315

1416

1517
@patch('sshuttle.helpers.verbose', new=3)
@@ -68,10 +70,18 @@ def test_setup_udp_listener():
6870
assert listener.mock_calls == []
6971

7072

71-
def test_check_settings():
73+
def test_assert_features():
7274
method = get_method('pf')
73-
method.check_settings(True, True)
74-
method.check_settings(False, True)
75+
features = method.get_supported_features()
76+
method.assert_features(features)
77+
78+
features.udp = True
79+
with pytest.raises(Fatal):
80+
method.assert_features(features)
81+
82+
features.ipv6 = True
83+
with pytest.raises(Fatal):
84+
method.assert_features(features)
7585

7686

7787
@patch('sshuttle.methods.pf.sys.stdout')

0 commit comments

Comments
 (0)