|
12 | 12 | from datetime import timedelta |
13 | 13 | from functools import cached_property |
14 | 14 | from hashlib import sha256 |
| 15 | +from typing import Set, List, Optional, Tuple, Dict |
15 | 16 |
|
16 | 17 | import dns |
17 | 18 | import psl_dns |
18 | 19 | import rest_framework.authtoken.models |
| 20 | +from cryptography import x509, hazmat |
19 | 21 | from django.conf import settings |
20 | 22 | from django.contrib.auth.hashers import make_password |
21 | 23 | from django.contrib.auth.models import AbstractBaseUser, AnonymousUser, BaseUserManager |
@@ -982,3 +984,209 @@ def verify(self, solution: str): |
982 | 984 | and |
983 | 985 | age <= settings.CAPTCHA_VALIDITY_PERIOD # not expired |
984 | 986 | ) |
| 987 | + |
| 988 | + |
| 989 | +class Identity(models.Model): |
| 990 | + rr_type = None |
| 991 | + |
| 992 | + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) |
| 993 | + name = models.CharField(max_length=24, default="") |
| 994 | + created = models.DateTimeField(auto_now_add=True) |
| 995 | + owner = models.ForeignKey(User, on_delete=models.PROTECT, related_name='identities') |
| 996 | + default_ttl = models.PositiveIntegerField(default=300) |
| 997 | + |
| 998 | + class Meta: |
| 999 | + abstract = True |
| 1000 | + |
| 1001 | + def get_record_contents(self) -> List[str]: |
| 1002 | + raise NotImplementedError |
| 1003 | + |
| 1004 | + def save_rrs(self): |
| 1005 | + raise NotImplementedError |
| 1006 | + |
| 1007 | + def save(self, *args, **kwargs): |
| 1008 | + self.save_rrs() |
| 1009 | + return super().save(*args, **kwargs) |
| 1010 | + |
| 1011 | + def delete_rrs(self): |
| 1012 | + raise NotImplementedError |
| 1013 | + |
| 1014 | + def delete(self, using=None, keep_parents=False): |
| 1015 | + # TODO this will delete also RRs that may be covered by other identities |
| 1016 | + self.delete_rrs() |
| 1017 | + return super().delete(using, keep_parents) |
| 1018 | + |
| 1019 | + def get_or_create_rr_set(self, domain: Domain, subname: str) -> RRset: |
| 1020 | + try: |
| 1021 | + return RRset.objects.get(domain=domain, subname=subname, type=self.rr_type) |
| 1022 | + except RRset.DoesNotExist: |
| 1023 | + # TODO save this RRset? |
| 1024 | + return RRset(domain=domain, subname=subname, type=self.rr_type, ttl=self.default_ttl) |
| 1025 | + |
| 1026 | + @staticmethod |
| 1027 | + def get_or_create_rr(rrset: RRset, content: str) -> RR: |
| 1028 | + try: |
| 1029 | + return RR.objects.get(rrset=rrset, content=content) |
| 1030 | + except RR.DoesNotExist: |
| 1031 | + return RR(rrset=rrset, content=content) |
| 1032 | + |
| 1033 | + |
| 1034 | +class TLSIdentity(Identity): |
| 1035 | + rr_type = 'TLSA' |
| 1036 | + |
| 1037 | + class CertificateUsage(models.IntegerChoices): |
| 1038 | + CA_CONSTRAINT = 0 |
| 1039 | + SERVICE_CERTIFICATE_CONSTRAINT = 1 |
| 1040 | + TRUST_ANCHOR_ASSERTION = 2 |
| 1041 | + DOMAIN_ISSUED_CERTIFICATE = 3 |
| 1042 | + |
| 1043 | + class Selector(models.IntegerChoices): |
| 1044 | + FULL_CERTIFICATE = 0 |
| 1045 | + SUBJECT_PUBLIC_KEY_INFO = 1 |
| 1046 | + |
| 1047 | + class MatchingType(models.IntegerChoices): |
| 1048 | + NO_HASH_USED = 0 |
| 1049 | + SHA256 = 1 |
| 1050 | + SHA512 = 2 |
| 1051 | + |
| 1052 | + class Protocol(models.TextChoices): |
| 1053 | + TCP = 'tcp' |
| 1054 | + UDP = 'udp' |
| 1055 | + SCTP = 'sctp' |
| 1056 | + |
| 1057 | + certificate = models.TextField() |
| 1058 | + |
| 1059 | + tlsa_selector = models.IntegerField(choices=Selector.choices, default=Selector.SUBJECT_PUBLIC_KEY_INFO) |
| 1060 | + tlsa_matching_type = models.IntegerField(choices=MatchingType.choices, default=MatchingType.SHA256) |
| 1061 | + tlsa_certificate_usage = models.IntegerField(choices=CertificateUsage.choices, |
| 1062 | + default=CertificateUsage.DOMAIN_ISSUED_CERTIFICATE) |
| 1063 | + |
| 1064 | + port = models.IntegerField(default=443) |
| 1065 | + protocol = models.TextField(choices=Protocol.choices, default=Protocol.TCP) |
| 1066 | + |
| 1067 | + scheduled_removal = models.DateTimeField(null=True) |
| 1068 | + |
| 1069 | + def __init__(self, *args, **kwargs): |
| 1070 | + super().__init__(*args, **kwargs) |
| 1071 | + if 'not_valid_after' not in kwargs: |
| 1072 | + self.scheduled_removal = self.not_valid_after |
| 1073 | + |
| 1074 | + def get_record_contents(self) -> List[str]: |
| 1075 | + # choose hash function |
| 1076 | + if self.tlsa_matching_type == self.MatchingType.SHA256: |
| 1077 | + hash_function = hazmat.primitives.hashes.SHA256() |
| 1078 | + elif self.tlsa_matching_type == self.MatchingType.SHA512: |
| 1079 | + hash_function = hazmat.primitives.hashes.SHA512() |
| 1080 | + else: |
| 1081 | + raise NotImplementedError |
| 1082 | + |
| 1083 | + # choose data to hash |
| 1084 | + if self.tlsa_selector == self.Selector.SUBJECT_PUBLIC_KEY_INFO: |
| 1085 | + to_be_hashed = self._cert.public_key().public_bytes( |
| 1086 | + hazmat.primitives.serialization.Encoding.DER, |
| 1087 | + hazmat.primitives.serialization.PublicFormat.SubjectPublicKeyInfo |
| 1088 | + ) |
| 1089 | + else: |
| 1090 | + raise NotImplementedError |
| 1091 | + |
| 1092 | + # compute the hash |
| 1093 | + h = hazmat.primitives.hashes.Hash(hash_function) |
| 1094 | + h.update(to_be_hashed) |
| 1095 | + hash = h.finalize().hex() |
| 1096 | + |
| 1097 | + # create TLSA record content |
| 1098 | + return [f"{self.tlsa_certificate_usage} {self.tlsa_selector} {self.tlsa_matching_type} {hash}"] |
| 1099 | + |
| 1100 | + @property |
| 1101 | + def _cert(self) -> x509.Certificate: |
| 1102 | + return x509.load_pem_x509_certificate(self.certificate.encode()) |
| 1103 | + |
| 1104 | + @property |
| 1105 | + def fingerprint(self) -> str: |
| 1106 | + return self._cert.fingerprint(hazmat.primitives.hashes.SHA256()).hex() |
| 1107 | + |
| 1108 | + @property |
| 1109 | + def subject_names(self) -> Set[str]: |
| 1110 | + subject_names = { |
| 1111 | + x.value for x in |
| 1112 | + self._cert.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME) |
| 1113 | + } |
| 1114 | + |
| 1115 | + try: |
| 1116 | + subject_alternative_names = { |
| 1117 | + x for x in |
| 1118 | + self._cert.extensions.get_extension_for_oid( |
| 1119 | + x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value.get_values_for_type(x509.DNSName) |
| 1120 | + } |
| 1121 | + except x509.extensions.ExtensionNotFound: |
| 1122 | + subject_alternative_names = set() |
| 1123 | + |
| 1124 | + return subject_names | subject_alternative_names |
| 1125 | + |
| 1126 | + @staticmethod |
| 1127 | + def get_closest_ancestor(domain_name, owner: User) -> Optional[Domain]: |
| 1128 | + # TODO move to Domain? |
| 1129 | + labels = domain_name.split('.') |
| 1130 | + ancestor_names = ['.'.join(labels[i:]) for i in range(len(labels))] |
| 1131 | + for ancestor_name in ancestor_names: # TODO do this with one query |
| 1132 | + try: |
| 1133 | + return Domain.objects.get(name=ancestor_name, owner=owner) |
| 1134 | + except Domain.DoesNotExist: |
| 1135 | + continue |
| 1136 | + return None |
| 1137 | + |
| 1138 | + def domains_subnames(self) -> Set[Tuple[Domain, str]]: |
| 1139 | + domains_subnames = set() |
| 1140 | + for name in self.subject_names: |
| 1141 | + # cut off any wildcard prefix |
| 1142 | + name = name.lstrip('*').lstrip('.') |
| 1143 | + |
| 1144 | + # filter names for valid domain names |
| 1145 | + try: |
| 1146 | + validate_domain_name[1](name) |
| 1147 | + except ValidationError: |
| 1148 | + continue |
| 1149 | + |
| 1150 | + # find user-owned parent domain |
| 1151 | + domain = self.get_closest_ancestor(name, self.owner) |
| 1152 | + if not domain: |
| 1153 | + continue |
| 1154 | + subname = name[:-len(domain.name)].rstrip('.') |
| 1155 | + |
| 1156 | + # return subname, domain pair |
| 1157 | + domains_subnames.add((domain, f"_{self.port:n}._{self.protocol}.{subname}".rstrip('.'))) |
| 1158 | + return domains_subnames |
| 1159 | + |
| 1160 | + def get_rrsets(self) -> List[RRset]: |
| 1161 | + rrsets = [] |
| 1162 | + for domain, subname in self.domains_subnames(): |
| 1163 | + rrsets.append(self.get_or_create_rr_set(domain, subname)) |
| 1164 | + return rrsets |
| 1165 | + |
| 1166 | + def get_rrs(self) -> List[RR]: |
| 1167 | + rrs = [] |
| 1168 | + for domain, subname in self.domains_subnames(): |
| 1169 | + rrset = self.get_or_create_rr_set(domain, subname) |
| 1170 | + for content in self.get_record_contents(): |
| 1171 | + rrs.append(self.get_or_create_rr(rrset=rrset, content=content)) |
| 1172 | + return rrs |
| 1173 | + |
| 1174 | + def save_rrs(self): |
| 1175 | + for rr in self.get_rrs(): |
| 1176 | + rr.rrset.save() |
| 1177 | + rr.save() |
| 1178 | + |
| 1179 | + def delete_rrs(self): |
| 1180 | + for domain, subname in self.domains_subnames(): |
| 1181 | + rrset = self.get_or_create_rr_set(domain, subname) |
| 1182 | + rrset.records.filter(content__in=self.get_record_contents()).delete() |
| 1183 | + if not len(rrset.records.all()): |
| 1184 | + rrset.delete() |
| 1185 | + |
| 1186 | + @property |
| 1187 | + def not_valid_before(self): |
| 1188 | + return self._cert.not_valid_before |
| 1189 | + |
| 1190 | + @property |
| 1191 | + def not_valid_after(self): |
| 1192 | + return self._cert.not_valid_after |
0 commit comments