Skip to content

Commit 74d87d4

Browse files
author
Roland Hedberg
committed
Cleaned up a bit
1 parent eb12108 commit 74d87d4

File tree

1 file changed

+104
-66
lines changed

1 file changed

+104
-66
lines changed

src/saml2/sigver.py

Lines changed: 104 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,16 @@
5252

5353
from tempfile import NamedTemporaryFile
5454
from subprocess import Popen, PIPE
55-
from xmlenc import EncryptionMethod, EncryptedKey, CipherData, CipherValue, \
56-
EncryptedData
55+
from xmlenc import EncryptionMethod
56+
from xmlenc import EncryptedKey
57+
from xmlenc import CipherData
58+
from xmlenc import CipherValue
59+
from xmlenc import EncryptedData
60+
61+
from Crypto.Hash import SHA256
62+
from Crypto.Hash import SHA384
63+
from Crypto.Hash import SHA512
64+
from Crypto.Hash import SHA
5765

5866
logger = logging.getLogger(__name__)
5967

@@ -63,7 +71,6 @@
6371
RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
6472
TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
6573

66-
from Crypto.Hash import SHA256, SHA384, SHA512, SHA
6774

6875

6976
class SigverError(SAMLError):
@@ -925,12 +932,14 @@ def security_context(conf, debug=None):
925932
raise SigverError('Unknown crypto_backend %s' % (
926933
repr(conf.crypto_backend)))
927934

928-
return SecurityContext(crypto, conf.key_file,
929-
cert_file=conf.cert_file, metadata=metadata,
930-
debug=debug, only_use_keys_in_metadata=_only_md,
931-
cert_handler_extra_class=conf.cert_handler_extra_class,
932-
generate_cert_info=conf.generate_cert_info, tmp_cert_file=conf.tmp_cert_file,
933-
tmp_key_file=conf.tmp_key_file, validate_certificate=conf.validate_certificate)
935+
return SecurityContext(
936+
crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata,
937+
debug=debug, only_use_keys_in_metadata=_only_md,
938+
cert_handler_extra_class=conf.cert_handler_extra_class,
939+
generate_cert_info=conf.generate_cert_info,
940+
tmp_cert_file=conf.tmp_cert_file,
941+
tmp_key_file=conf.tmp_key_file,
942+
validate_certificate=conf.validate_certificate)
934943

935944

936945
class CertHandlerExtra(object):
@@ -940,7 +949,8 @@ def __init__(self):
940949
def use_generate_cert_func(self):
941950
raise Exception("use_generate_cert_func function must be implemented")
942951

943-
def generate_cert(self, generate_cert_info, root_cert_string, root_key_string):
952+
def generate_cert(self, generate_cert_info, root_cert_string,
953+
root_key_string):
944954
raise Exception("generate_cert function must be implemented")
945955
#Excepts to return (cert_string, key_string)
946956

@@ -953,12 +963,14 @@ def validate_cert(self, cert_str, root_cert_string, root_key_string):
953963

954964

955965
class CertHandler(object):
956-
def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=None, key_type="pem",
957-
generate_cert_info=None, cert_handler_extra_class=None, tmp_cert_file=None, tmp_key_file=None,
958-
verify_cert=False):
966+
def __init__(self, security_context, cert_file=None, cert_type="pem",
967+
key_file=None, key_type="pem", generate_cert_info=None,
968+
cert_handler_extra_class=None, tmp_cert_file=None,
969+
tmp_key_file=None, verify_cert=False):
959970
"""
960-
Initiates the class for handling certificates. Enables the certificates to either be a single certificate
961-
as base functionality or makes it possible to generate a new certificate for each call to the function.
971+
Initiates the class for handling certificates. Enables the certificates
972+
to either be a single certificate as base functionality or makes it
973+
possible to generate a new certificate for each call to the function.
962974
:param key_file:
963975
:param key_type:
964976
:param cert_file:
@@ -968,7 +980,9 @@ def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=N
968980
"""
969981
self._verify_cert = False
970982
self._generate_cert = False
971-
self._last_cert_verified = None #This cert do not have to be valid, it is just the last cert to be validated.
983+
#This cert do not have to be valid, it is just the last cert to be
984+
# validated.
985+
self._last_cert_verified = None
972986
if cert_type == "pem" and key_type == "pem":
973987
self._verify_cert = verify_cert is True
974988
self._security_context = security_context
@@ -978,7 +992,8 @@ def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=N
978992
else:
979993
self._key_str = ""
980994
if cert_file is not None:
981-
self._cert_str = self._osw.read_str_from_file(cert_file, cert_type)
995+
self._cert_str = self._osw.read_str_from_file(cert_file,
996+
cert_type)
982997
else:
983998
self._cert_str = ""
984999

@@ -989,8 +1004,9 @@ def __init__(self, security_context, cert_file=None, cert_type="pem", key_file=N
9891004

9901005
self._cert_info = None
9911006
self._generate_cert_func_active = False
992-
if generate_cert_info is not None and len(self._cert_str) > 0 and len(self._key_str) > 0 \
993-
and tmp_key_file is not None and tmp_cert_file is not None:
1007+
if generate_cert_info is not None and len(self._cert_str) > 0 and \
1008+
len(self._key_str) > 0 and tmp_key_file is not \
1009+
None and tmp_cert_file is not None:
9941010
self._generate_cert = True
9951011
self._cert_info = generate_cert_info
9961012
self._cert_handler_extra_class = cert_handler_extra_class
@@ -999,8 +1015,10 @@ def verify_cert(self, cert_file):
9991015
if self._verify_cert:
10001016
cert_str = self._osw.read_str_from_file(cert_file, "pem")
10011017
self._last_validated_cert = cert_str
1002-
if self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_validate_cert_func():
1003-
self._cert_handler_extra_class.validate_cert(cert_str, self._cert_str, self._key_str)
1018+
if self._cert_handler_extra_class is not None and \
1019+
self._cert_handler_extra_class.use_validate_cert_func():
1020+
self._cert_handler_extra_class.validate_cert(
1021+
cert_str, self._cert_str, self._key_str)
10041022
else:
10051023
valid, mess = self._osw.verify(self._cert_str, cert_str)
10061024
logger.info("CertHandler.verify_cert: %s" % mess)
@@ -1016,22 +1034,25 @@ def update_cert(self, active=False, client_crt=None):
10161034
self._tmp_cert_str = client_crt
10171035
#No private key for signing
10181036
self._tmp_key_str = ""
1019-
elif self._cert_handler_extra_class is not None and self._cert_handler_extra_class.use_generate_cert_func():
1037+
elif self._cert_handler_extra_class is not None and \
1038+
self._cert_handler_extra_class.use_generate_cert_func():
10201039
(self._tmp_cert_str, self._tmp_key_str) = \
10211040
self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str)
10221041
else:
10231042
self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True)
1024-
self._tmp_cert_str = self._osw.create_cert_signed_certificate(self._cert_str, self._key_str,
1025-
self._tmp_cert_str)
1026-
valid, mess = self._osw.verify(self._cert_str, self._tmp_cert_str)
1043+
self._tmp_cert_str = self._osw.create_cert_signed_certificate(
1044+
self._cert_str, self._key_str, self._tmp_cert_str)
1045+
valid, mess = self._osw.verify(self._cert_str,
1046+
self._tmp_cert_str)
10271047
self._osw.write_str_to_file(self._tmp_cert_file, self._tmp_cert_str)
10281048
self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str)
10291049
self._security_context.key_file = self._tmp_key_file
10301050
self._security_context.cert_file = self._tmp_cert_file
10311051
self._security_context.key_type = "pem"
10321052
self._security_context.cert_type = "pem"
1033-
self._security_context.my_cert = read_cert_from_file(self._security_context.cert_file,
1034-
self._security_context.cert_type)
1053+
self._security_context.my_cert = read_cert_from_file(
1054+
self._security_context.cert_file,
1055+
self._security_context.cert_type)
10351056

10361057

10371058
# How to get a rsa pub key fingerprint from a certificate
@@ -1043,8 +1064,9 @@ class SecurityContext(object):
10431064
def __init__(self, crypto, key_file="", key_type="pem",
10441065
cert_file="", cert_type="pem", metadata=None,
10451066
debug=False, template="", encrypt_key_type="des-192",
1046-
only_use_keys_in_metadata=False, cert_handler_extra_class=None, generate_cert_info=None,
1047-
tmp_cert_file=None, tmp_key_file=None, validate_certificate=None):
1067+
only_use_keys_in_metadata=False, cert_handler_extra_class=None,
1068+
generate_cert_info=None, tmp_cert_file=None,
1069+
tmp_key_file=None, validate_certificate=None):
10481070

10491071
self.crypto = crypto
10501072
assert (isinstance(self.crypto, CryptoBackend))
@@ -1059,8 +1081,10 @@ def __init__(self, crypto, key_file="", key_type="pem",
10591081

10601082
self.my_cert = read_cert_from_file(cert_file, cert_type)
10611083

1062-
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, key_type, generate_cert_info,
1063-
cert_handler_extra_class, tmp_cert_file, tmp_key_file, validate_certificate)
1084+
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,
1085+
key_type, generate_cert_info,
1086+
cert_handler_extra_class, tmp_cert_file,
1087+
tmp_key_file, validate_certificate)
10641088

10651089
self.cert_handler.update_cert(True)
10661090

@@ -1135,7 +1159,8 @@ def verify_signature(self, signedtext, cert_file=None, cert_type="pem",
11351159
)
11361160

11371161
def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
1138-
origdoc=None, id_attr="", must=False, only_valid_cert=False):
1162+
origdoc=None, id_attr="", must=False,
1163+
only_valid_cert=False):
11391164
#print item
11401165
try:
11411166
issuer = item.issuer.text.strip()
@@ -1179,13 +1204,15 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
11791204
try:
11801205
if self.verify_signature(origdoc, pem_file,
11811206
node_name=node_name,
1182-
node_id=item.id, id_attr=id_attr):
1207+
node_id=item.id,
1208+
id_attr=id_attr):
11831209
verified = True
11841210
break
11851211
except Exception:
11861212
if self.verify_signature(decoded_xml, pem_file,
11871213
node_name=node_name,
1188-
node_id=item.id, id_attr=id_attr):
1214+
node_id=item.id,
1215+
id_attr=id_attr):
11891216
verified = True
11901217
break
11911218
else:
@@ -1247,91 +1274,101 @@ def correctly_signed_message(self, decoded_xml, msgtype, must=False,
12471274
return msg
12481275

12491276
return self._check_signature(decoded_xml, msg, class_name(msg),
1250-
origdoc, must=must, only_valid_cert=only_valid_cert)
1277+
origdoc, must=must,
1278+
only_valid_cert=only_valid_cert)
12511279

12521280
def correctly_signed_authn_request(self, decoded_xml, must=False,
12531281
origdoc=None, only_valid_cert=False):
12541282
return self.correctly_signed_message(decoded_xml, "authn_request",
1255-
must, origdoc, only_valid_cert=only_valid_cert)
1283+
must, origdoc,
1284+
only_valid_cert=only_valid_cert)
12561285

12571286
def correctly_signed_authn_query(self, decoded_xml, must=False,
1258-
origdoc=None):
1287+
origdoc=None, only_valid_cert=False):
12591288
return self.correctly_signed_message(decoded_xml, "authn_query",
1260-
must, origdoc)
1289+
must, origdoc, only_valid_cert)
12611290

12621291
def correctly_signed_logout_request(self, decoded_xml, must=False,
1263-
origdoc=None):
1292+
origdoc=None, only_valid_cert=False):
12641293
return self.correctly_signed_message(decoded_xml, "logout_request",
1265-
must, origdoc)
1294+
must, origdoc, only_valid_cert)
12661295

12671296
def correctly_signed_logout_response(self, decoded_xml, must=False,
1268-
origdoc=None):
1297+
origdoc=None, only_valid_cert=False):
12691298
return self.correctly_signed_message(decoded_xml, "logout_response",
1270-
must, origdoc)
1299+
must, origdoc, only_valid_cert)
12711300

12721301
def correctly_signed_attribute_query(self, decoded_xml, must=False,
1273-
origdoc=None):
1302+
origdoc=None, only_valid_cert=False):
12741303
return self.correctly_signed_message(decoded_xml, "attribute_query",
1275-
must, origdoc)
1304+
must, origdoc, only_valid_cert)
12761305

12771306
def correctly_signed_authz_decision_query(self, decoded_xml, must=False,
1278-
origdoc=None):
1307+
origdoc=None,
1308+
only_valid_cert=False):
12791309
return self.correctly_signed_message(decoded_xml,
12801310
"authz_decision_query", must,
1281-
origdoc)
1311+
origdoc, only_valid_cert)
12821312

12831313
def correctly_signed_authz_decision_response(self, decoded_xml, must=False,
1284-
origdoc=None):
1314+
origdoc=None,
1315+
only_valid_cert=False):
12851316
return self.correctly_signed_message(decoded_xml,
12861317
"authz_decision_response", must,
1287-
origdoc)
1318+
origdoc, only_valid_cert)
12881319

12891320
def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False,
1290-
origdoc=None):
1321+
origdoc=None,
1322+
only_valid_cert=False):
12911323
return self.correctly_signed_message(decoded_xml,
12921324
"name_id_mapping_request",
1293-
must, origdoc)
1325+
must, origdoc, only_valid_cert)
12941326

12951327
def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False,
1296-
origdoc=None):
1328+
origdoc=None,
1329+
only_valid_cert=False):
12971330
return self.correctly_signed_message(decoded_xml,
12981331
"name_id_mapping_response",
1299-
must, origdoc)
1332+
must, origdoc, only_valid_cert)
13001333

13011334
def correctly_signed_artifact_request(self, decoded_xml, must=False,
1302-
origdoc=None):
1335+
origdoc=None, only_valid_cert=False):
13031336
return self.correctly_signed_message(decoded_xml,
13041337
"artifact_request",
1305-
must, origdoc)
1338+
must, origdoc, only_valid_cert)
13061339

13071340
def correctly_signed_artifact_response(self, decoded_xml, must=False,
1308-
origdoc=None):
1341+
origdoc=None, only_valid_cert=False):
13091342
return self.correctly_signed_message(decoded_xml,
13101343
"artifact_response",
1311-
must, origdoc)
1344+
must, origdoc, only_valid_cert)
13121345

13131346
def correctly_signed_manage_name_id_request(self, decoded_xml, must=False,
1314-
origdoc=None):
1347+
origdoc=None,
1348+
only_valid_cert=False):
13151349
return self.correctly_signed_message(decoded_xml,
13161350
"manage_name_id_request",
1317-
must, origdoc)
1351+
must, origdoc, only_valid_cert)
13181352

13191353
def correctly_signed_manage_name_id_response(self, decoded_xml, must=False,
1320-
origdoc=None):
1354+
origdoc=None,
1355+
only_valid_cert=False):
13211356
return self.correctly_signed_message(decoded_xml,
13221357
"manage_name_id_response", must,
1323-
origdoc)
1358+
origdoc, only_valid_cert)
13241359

13251360
def correctly_signed_assertion_id_request(self, decoded_xml, must=False,
1326-
origdoc=None):
1361+
origdoc=None,
1362+
only_valid_cert=False):
13271363
return self.correctly_signed_message(decoded_xml,
13281364
"assertion_id_request", must,
1329-
origdoc)
1365+
origdoc, only_valid_cert)
13301366

13311367
def correctly_signed_assertion_id_response(self, decoded_xml, must=False,
1332-
origdoc=None):
1368+
origdoc=None,
1369+
only_valid_cert=False):
13331370
return self.correctly_signed_message(decoded_xml, "assertion", must,
1334-
origdoc)
1371+
origdoc, only_valid_cert)
13351372

13361373
def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
13371374
""" Check if a instance is correctly signed, if we have metadata for
@@ -1353,11 +1390,12 @@ def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
13531390
origdoc)
13541391

13551392
if isinstance(response, Response) and (response.assertion or
1356-
response.encrypted_assertion):
1393+
response.encrypted_assertion):
13571394
# Try to find the signing cert in the assertion
13581395
for assertion in (response.assertion or response.encrypted_assertion):
13591396
if response.encrypted_assertion:
1360-
decoded_xml = self.decrypt(assertion.encrypted_data.to_string())
1397+
decoded_xml = self.decrypt(
1398+
assertion.encrypted_data.to_string())
13611399
assertion = saml.assertion_from_string(decoded_xml)
13621400

13631401
if not assertion.signature:

0 commit comments

Comments
 (0)