1
1
from mock import Mock , patch , call
2
2
import io
3
- import os
4
- import os .path
5
- import shutil
6
- import filecmp
7
3
8
4
import sshuttle .firewall
9
5
@@ -19,27 +15,27 @@ def setup_daemon():
19
15
10,2404:6800:4004:80c::33
20
16
PORTS 1024,1025,1026,1027
21
17
GO 1
18
+ HOST 1.2.3.3,existing
22
19
""" )
23
20
stdout = Mock ()
24
21
return stdin , stdout
25
22
26
23
27
- @patch ('sshuttle.firewall.HOSTSFILE' , new = 'tmp/hosts' )
28
- @patch ('sshuttle.firewall.hostmap' , new = {
29
- 'myhost' : '1.2.3.4' ,
30
- 'myotherhost' : '1.2.3.5' ,
31
- })
32
- def test_rewrite_etc_hosts ():
33
- if not os .path .isdir ("tmp" ):
34
- os .mkdir ("tmp" )
24
+ def test_rewrite_etc_hosts (tmpdir ):
25
+ orig_hosts = tmpdir .join ("hosts.orig" )
26
+ orig_hosts .write ("1.2.3.3 existing\n " )
35
27
36
- with open ( "tmp/ hosts.orig" , "w" ) as f :
37
- f . write ( "1.2.3.3 existing \n " )
28
+ new_hosts = tmpdir . join ( " hosts" )
29
+ orig_hosts . copy ( new_hosts )
38
30
39
- shutil .copyfile ("tmp/hosts.orig" , "tmp/hosts" )
31
+ hostmap = {
32
+ 'myhost' : '1.2.3.4' ,
33
+ 'myotherhost' : '1.2.3.5' ,
34
+ }
35
+ with patch ('sshuttle.firewall.HOSTSFILE' , new = str (new_hosts )):
36
+ sshuttle .firewall .rewrite_etc_hosts (hostmap , 10 )
40
37
41
- sshuttle .firewall .rewrite_etc_hosts (10 )
42
- with open ("tmp/hosts" ) as f :
38
+ with new_hosts .open () as f :
43
39
line = f .readline ()
44
40
s = line .split ()
45
41
assert s == ['1.2.3.3' , 'existing' ]
@@ -57,39 +53,37 @@ def test_rewrite_etc_hosts():
57
53
line = f .readline ()
58
54
assert line == ""
59
55
60
- sshuttle .firewall .restore_etc_hosts (10 )
61
- assert filecmp .cmp ("tmp/hosts.orig" , "tmp/hosts" , shallow = False ) is True
56
+ with patch ('sshuttle.firewall.HOSTSFILE' , new = str (new_hosts )):
57
+ sshuttle .firewall .restore_etc_hosts (10 )
58
+ assert orig_hosts .computehash () == new_hosts .computehash ()
62
59
63
60
64
- @patch ('sshuttle.firewall.HOSTSFILE' , new = 'tmp/hosts ' )
61
+ @patch ('sshuttle.firewall.rewrite_etc_hosts ' )
65
62
@patch ('sshuttle.firewall.setup_daemon' )
66
63
@patch ('sshuttle.firewall.get_method' )
67
- def test_main (mock_get_method , mock_setup_daemon ):
64
+ def test_main (mock_get_method , mock_setup_daemon , mock_rewrite_etc_hosts ):
68
65
stdin , stdout = setup_daemon ()
69
66
mock_setup_daemon .return_value = stdin , stdout
70
67
71
- if not os . path . isdir ( "tmp" ):
72
- os . mkdir ( "tmp" )
68
+ mock_get_method ( "not_auto" ). name = "test"
69
+ mock_get_method . reset_mock ( )
73
70
74
- sshuttle .firewall .main ("test " , False )
71
+ sshuttle .firewall .main ("not_auto " , False )
75
72
76
- with open ("tmp/hosts" ) as f :
77
- line = f .readline ()
78
- s = line .split ()
79
- assert s == ['1.2.3.3' , 'existing' ]
80
-
81
- line = f .readline ()
82
- assert line == ""
73
+ assert mock_rewrite_etc_hosts .mock_calls == [
74
+ call ({'1.2.3.3' : 'existing' }, 1024 ),
75
+ call ({}, 1024 ),
76
+ ]
83
77
84
- stdout .mock_calls == [
78
+ assert stdout .mock_calls == [
85
79
call .write ('READY test\n ' ),
86
80
call .flush (),
87
81
call .write ('STARTED\n ' ),
88
82
call .flush ()
89
83
]
90
- mock_setup_daemon .mock_calls == [call ()]
91
- mock_get_method .mock_calls == [
92
- call ('test ' ),
84
+ assert mock_setup_daemon .mock_calls == [call ()]
85
+ assert mock_get_method .mock_calls == [
86
+ call ('not_auto ' ),
93
87
call ().setup_firewall (
94
88
1024 , 1026 ,
95
89
[(10 , u'2404:6800:4004:80c::33' )],
@@ -104,6 +98,7 @@ def test_main(mock_get_method, mock_setup_daemon):
104
98
[(2 , 24 , False , u'1.2.3.0' ), (2 , 32 , True , u'1.2.3.66' )],
105
99
True ),
106
100
call ().setup_firewall ()(),
101
+ call ().setup_firewall ()(),
107
102
call ().setup_firewall (1024 , 0 , [], 10 , [], True ),
108
103
call ().setup_firewall (1025 , 0 , [], 2 , [], True ),
109
104
]
0 commit comments