Skip to content

Commit 3da60ab

Browse files
committed
refactor permissions
1 parent a214987 commit 3da60ab

File tree

8 files changed

+190
-132
lines changed

8 files changed

+190
-132
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import os
2+
import base64
3+
import hashlib
4+
from zope.interface import (
5+
Interface,
6+
Attribute,
7+
implementer,
8+
)
9+
10+
11+
class IPermission(Interface):
12+
"""
13+
A server-side method of granting permission to a client.
14+
"""
15+
name = Attribute("name")
16+
17+
def get_welcome_data():
18+
"""
19+
return a dict of information to include under the name of this
20+
Permission granter (under "permission-required" in the Welcome)
21+
"""
22+
23+
def verify_permission(submit_permission):
24+
"""
25+
return a bool indicating if the submit_permission data is a valid
26+
permission (or not)
27+
"""
28+
29+
30+
def create_permission_provider(kind):
31+
"""
32+
returns a permissions-provider
33+
"""
34+
if kind == "none":
35+
return NoPermission
36+
elif kind == "hashcash":
37+
return HashcashPermission
38+
raise ValueError(
39+
"Unknown permission provider '{}'".format(kind)
40+
)
41+
42+
43+
@implementer(IPermission)
44+
class NoPermission(object):
45+
"""
46+
A no-op permission provider used to grant any client access (the
47+
default).
48+
"""
49+
name = "none"
50+
51+
def get_welcome_data(self):
52+
return {}
53+
54+
def verify_permission(self, submit_permission):
55+
return True
56+
57+
58+
@implementer(IPermission)
59+
class HashcashPermission(object):
60+
"""
61+
A permission provider that generates a random 'resource' string
62+
and checks a proof-of-work from the client.
63+
"""
64+
name = "hashcash"
65+
66+
def __init__(self, bits=20):
67+
self._bits = bits
68+
69+
def get_welcome_data(self):
70+
"""
71+
Generate the data to include under this method's key in the
72+
`permission-required` value of the welcome message.
73+
74+
Should be called at most once per connection.
75+
"""
76+
self._hashcash_resource = base64.b64encode(os.urandom(8)).decode("utf8")
77+
return {
78+
"bits": self._bits,
79+
"resource": self._hashcash_resource,
80+
}
81+
82+
def verify_permission(self, perms):
83+
"""
84+
:returns bool: an indication of whether the provided permissions
85+
reply from a client is valid
86+
"""
87+
# XXX THINK do we need this whole method to be constant-time?
88+
# (basically impossible if it's not even syntactially valid?)
89+
stamp = perms.get("stamp", "")
90+
fields = stamp.split(":")
91+
if len(fields) != 7:
92+
return False
93+
vers, claimed_bits, date, resource, ext, rand, counter = fields
94+
vers = int(vers)
95+
if vers != 1:
96+
return False
97+
if resource != self._hashcash_resource:
98+
return False
99+
100+
claimed_bits = int(claimed_bits)
101+
if claimed_bits < self._bits:
102+
return False
103+
104+
h = hashlib.sha1()
105+
h.update(stamp.encode("utf8"))
106+
measured_hash = h.digest()
107+
if leading_zero_bits(measured_hash) < claimed_bits:
108+
return False
109+
return True
110+
111+
112+
def leading_zero_bits(bytestring):
113+
"""
114+
:returns int: the number of leading zeros in the given byte-string
115+
"""
116+
measured_bits = 0
117+
for byte in bytestring:
118+
bit = 1 << 7
119+
while bit:
120+
if byte & bit:
121+
return measured_bits
122+
else:
123+
measured_bits += 1
124+
bit = bit >> 1
125+
126+

src/wormhole_mailbox_server/server.py

Lines changed: 10 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import namedtuple
55
from twisted.python import log
66
from twisted.application import service
7+
from .permission import create_permission_provider
78

89
def generate_mailbox_id():
910
return base64.b32encode(os.urandom(8)).lower().strip(b"=").decode("ascii")
@@ -552,102 +553,9 @@ def _shutdown(self):
552553
channel._shutdown()
553554

554555

555-
def leading_zero_bits(bytestring):
556-
"""
557-
:returns int: the number of leading zeros in the given byte-string
558-
"""
559-
measured_bits = 0
560-
for byte in bytestring:
561-
bit = 1 << 7
562-
while bit:
563-
if byte & bit:
564-
return measured_bits
565-
else:
566-
measured_bits += 1
567-
bit = bit >> 1
568-
569-
570-
class NoPermission(object):
571-
"""
572-
A no-op permission provider used to grant any client access (the
573-
default).
574-
"""
575-
name = "none"
576-
577-
def get_welcome_data(self):
578-
return {}
579-
580-
def verify_permission(self, submit_permission):
581-
return True
582-
583-
def is_passed(self):
584-
return True
585-
586-
587-
class HashcashPermission(object):
588-
"""
589-
A permission provider that generates a random 'resource' string
590-
and checks a proof-of-work from the client.
591-
"""
592-
name = "hashcash"
593-
594-
def __init__(self, bits=20):
595-
self._bits = bits
596-
self._passed = False
597-
598-
def get_welcome_data(self):
599-
"""
600-
Generate the data to include under this method's key in the
601-
`permission-required` value of the welcome message.
602-
603-
Should be called at most once per connection.
604-
"""
605-
self._hashcash_resource = base64.b64encode(os.urandom(8)).decode("utf8")
606-
return {
607-
"bits": self._bits,
608-
"resource": self._hashcash_resource,
609-
}
610-
611-
def is_passed(self):
612-
"""
613-
:returns bool: True if verify_permission has been called successfully
614-
"""
615-
return self._passed
616-
617-
def verify_permission(self, perms):
618-
"""
619-
:returns bool: an indication of whether the provided permissions
620-
reply from a client is valid
621-
"""
622-
# XXX THINK do we need this whole method to be constant-time?
623-
# (basically impossible if it's not even syntactially valid?)
624-
stamp = perms.get("stamp", "")
625-
fields = stamp.split(":")
626-
if len(fields) != 7:
627-
return False
628-
vers, claimed_bits, date, resource, ext, rand, counter = fields
629-
vers = int(vers)
630-
if vers != 1:
631-
return False
632-
if resource != self._hashcash_resource:
633-
return False
634-
635-
claimed_bits = int(claimed_bits)
636-
if claimed_bits < self._bits:
637-
return False
638-
639-
h = hashlib.sha1()
640-
h.update(stamp.encode("utf8"))
641-
measured_hash = h.digest()
642-
if leading_zero_bits(measured_hash) < claimed_bits:
643-
return False
644-
self._passed = True
645-
return True
646-
647-
648556
class Server(service.MultiService):
649557
def __init__(self, db, allow_list, welcome,
650-
blur_usage, usage_db=None, log_file=None, permissions="none"):
558+
blur_usage, usage_db=None, log_file=None, permission_provider=None):
651559
service.MultiService.__init__(self)
652560
self._db = db
653561
self._allow_list = allow_list
@@ -656,8 +564,8 @@ def __init__(self, db, allow_list, welcome,
656564
self._log_requests = blur_usage is None
657565
self._usage_db = usage_db
658566
self._log_file = log_file
659-
self._permissions = permissions
660-
assert self._permissions in ("none", "hashcash")
567+
self._permission_provider = permission_provider
568+
# XXX assert interface instead assert self._permissions in ("none", "hashcash")
661569
self._apps = {}
662570

663571
def get_welcome(self):
@@ -678,14 +586,7 @@ def get_permission_method(self):
678586
679587
:returns IPermissionGranter: a method of permission
680588
"""
681-
if self._permissions == "none":
682-
return NoPermission()
683-
elif self._permissions == "hashcash":
684-
return HashcashPermission()
685-
else:
686-
raise ValueError(
687-
'Unknown permission "{}"'.format(self._permissions)
688-
)
589+
return self._permission_provider()
689590

690591
def get_log_requests(self):
691592
return self._log_requests
@@ -801,7 +702,7 @@ def make_server(db, allow_list=True,
801702
advertise_version=None,
802703
signal_error=None,
803704
blur_usage=None,
804-
permissions="none",
705+
permission_provider=None,
805706
usage_db=None,
806707
log_file=None,
807708
welcome_motd=None,
@@ -827,6 +728,9 @@ def make_server(db, allow_list=True,
827728
if signal_error:
828729
welcome["error"] = signal_error
829730

731+
if permission_provider is None:
732+
permission_provider = create_permission_provider("none")
733+
830734
return Server(db, allow_list=allow_list, welcome=welcome,
831735
blur_usage=blur_usage, usage_db=usage_db, log_file=log_file,
832-
permissions=permissions)
736+
permission_provider=permission_provider)

src/wormhole_mailbox_server/server_tap.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .server import make_server
1010
from .web import make_web_server
1111
from .database import create_or_upgrade_channel_db, create_or_upgrade_usage_db
12+
from .permission import create_permission_provider
1213

1314
LONGDESC = """This plugin sets up a 'Mailbox' server for magic-wormhole.
1415
This service forwards short messages between clients, to perform key exchange
@@ -92,16 +93,18 @@ def makeService(config, channel_db="relay.sqlite", reactor=reactor):
9293
log_file = (os.fdopen(int(config["log-fd"]), "w")
9394
if config["log-fd"] is not None
9495
else None)
95-
server = make_server(channel_db,
96-
allow_list=config["allow-list"],
97-
advertise_version=config["advertise-version"],
98-
signal_error=config["signal-error"],
99-
blur_usage=config["blur-usage"],
100-
permissions=config["permissions"],
101-
usage_db=usage_db,
102-
log_file=log_file,
103-
welcome_motd=config["motd"],
104-
)
96+
97+
server = make_server(
98+
channel_db,
99+
allow_list=config["allow-list"],
100+
advertise_version=config["advertise-version"],
101+
signal_error=config["signal-error"],
102+
blur_usage=config["blur-usage"],
103+
permission_provider=create_permission_provider(config.get("permissions", "none")),
104+
usage_db=usage_db,
105+
log_file=log_file,
106+
welcome_motd=config["motd"],
107+
)
105108
server.setServiceParent(parent)
106109
rebooted = time.time()
107110
def expire():

src/wormhole_mailbox_server/server_websocket.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from twisted.internet import reactor
44
from twisted.python import log
55
from autobahn.twisted import websocket
6-
from .server import CrowdedError, ReclaimedError, SidedMessage, NoPermission
6+
from .server import CrowdedError, ReclaimedError, SidedMessage
7+
from .permission import NoPermission
78
from .util import dict_to_bytes, bytes_to_dict
89

910
# The WebSocket allows the client to send "commands" to the server, and the
@@ -110,6 +111,7 @@ def __init__(self):
110111
self._mailbox_id = None
111112
self._did_close = False
112113
self._permission = None
114+
self._permission_passed = False
113115

114116
def onConnect(self, request):
115117
rv = self.factory.server
@@ -187,7 +189,7 @@ def handle_ping(self, msg):
187189

188190
def handle_bind(self, msg, server_rx):
189191
# if demanding permission, but no permission yet .. error
190-
if self._permission is not None and not self._permission.is_passed():
192+
if not isinstance(self._permission, NoPermission) and not self._permission_passed:
191193
raise Error("must submit-permission first")
192194

193195
if self._app or self._side:
@@ -205,7 +207,8 @@ def handle_bind(self, msg, server_rx):
205207
def handle_submit_permissions(self, msg, server_rx):
206208
if msg.get("method", None) != self._permission.name:
207209
raise Error("need permission method '{}'".format(self._permission.name))
208-
if not self._permission.verify_permission(msg):
210+
self._permission_passed = self._permission.verify_permission(msg)
211+
if not self._permission_passed:
209212
raise Error("submit-permission failed")
210213

211214
def handle_list(self):

0 commit comments

Comments
 (0)