diff --git a/addon_service/authorized_storage_account/models.py b/addon_service/authorized_storage_account/models.py index bbe56cf1..9fcef2ba 100644 --- a/addon_service/authorized_storage_account/models.py +++ b/addon_service/authorized_storage_account/models.py @@ -15,6 +15,7 @@ from addon_service.common.validators import validate_addon_capability from addon_service.credentials.models import ExternalCredentials from addon_service.oauth1 import utils as oauth1_utils +from addon_service.oauth1.models import OAuth1TemporaryCredentials from addon_service.oauth2 import utils as oauth2_utils from addon_service.oauth2.models import ( OAuth2ClientConfig, @@ -24,10 +25,7 @@ AddonCapabilities, AddonImp, ) -from addon_toolkit.credentials import ( - Credentials, - OAuth1Credentials, -) +from addon_toolkit.credentials import OAuth1Credentials from addon_toolkit.interfaces.storage import StorageConfig @@ -73,8 +71,8 @@ class AuthorizedStorageAccount(AddonsServiceBaseModel): related_name="authorized_storage_account", ) _temporary_oauth1_credentials = models.OneToOneField( - "addon_service.ExternalCredentials", - on_delete=models.CASCADE, + "addon_service.OAuth1TemporaryCredentials", + on_delete=models.SET_NULL, primary_key=False, null=True, blank=True, @@ -116,20 +114,26 @@ def credentials_format(self): @property def credentials(self): if self._credentials: - return self._credentials.decrypted_credentials + return self._credentials.decrypted_dataclass return None @credentials.setter def credentials(self, credentials_data): - if self.temporary_oauth1_credentials: + if self._temporary_oauth1_credentials: self._temporary_oauth1_credentials.delete() self._temporary_oauth1_credentials = None - self._set_credentials("_credentials", credentials_data) + if not self._credentials: + self._credentials = ExternalCredentials.new() + try: + self._credentials.decrypted_dataclass = credentials_data + self._credentials.save() + except TypeError as e: + raise ValidationError(e) @property def temporary_oauth1_credentials(self) -> OAuth1Credentials | None: if self._temporary_oauth1_credentials: - return self._temporary_oauth1_credentials.decrypted_credentials + return self._temporary_oauth1_credentials.decrypted_dataclass return None @temporary_oauth1_credentials.setter @@ -138,23 +142,11 @@ def temporary_oauth1_credentials(self, credentials_data: OAuth1Credentials): raise ValidationError( "Trying to set temporary credentials for non OAuth1A account" ) - self._set_credentials("_temporary_oauth1_credentials", credentials_data) - - def _set_credentials(self, credentials_field: str, credentials_data: Credentials): - creds_type = type(credentials_data) - if not hasattr(self, credentials_field): - raise ValidationError("Trying to set credentials to non-existing field") - if creds_type is not self.credentials_format.dataclass: - raise ValidationError( - f"Expected credentials of type type {self.credentials_format.dataclass}." - f"Got credentials of type {creds_type}." - ) - if not getattr(self, credentials_field, None): - setattr(self, credentials_field, ExternalCredentials.new()) + if not self._temporary_oauth1_credentials: + self._temporary_oauth1_credentials = OAuth1TemporaryCredentials.new() try: - creds = getattr(self, credentials_field) - creds.decrypted_credentials = credentials_data - creds.save() + self._temporary_oauth1_credentials.decrypted_dataclass = credentials_data + self._temporary_oauth1_credentials.save() except TypeError as e: raise ValidationError(e) @@ -204,15 +196,18 @@ def auth_url(self) -> str | None: return self.oauth2_auth_url case CredentialsFormats.OAUTH1A: return self.oauth1_auth_url + return None @property - def oauth1_auth_url(self) -> str: + def oauth1_auth_url(self) -> str | None: client_config = self.external_service.oauth1_client_config - if self._temporary_oauth1_credentials: + _temporary_creds = self.temporary_oauth1_credentials + if _temporary_creds: return oauth1_utils.build_auth_url( auth_uri=client_config.auth_url, - temporary_oauth_token=self.temporary_oauth1_credentials.oauth_token, + temporary_oauth_token=_temporary_creds.oauth_token, ) + return None @property def oauth2_auth_url(self) -> str | None: diff --git a/addon_service/common/encrypted_dataclass_model.py b/addon_service/common/encrypted_dataclass_model.py new file mode 100644 index 00000000..8486d997 --- /dev/null +++ b/addon_service/common/encrypted_dataclass_model.py @@ -0,0 +1,86 @@ +import typing + +from django.core.exceptions import ValidationError +from django.db import models + +from addon_service.common import encryption +from addon_service.common.base_model import AddonsServiceBaseModel +from addon_service.common.dibs import dibs +from addon_toolkit.json_arguments import json_for_dataclass + + +_DATACLASS = typing.TypeVar("_DATACLASS") + + +class EncryptedDataclassModel(AddonsServiceBaseModel, typing.Generic[_DATACLASS]): + class Meta: + abstract = True + + encrypted_json = models.BinaryField() + _salt = models.BinaryField() + _scrypt_block_size = models.IntegerField() + _scrypt_cost_log2 = models.IntegerField() + _scrypt_parallelization = models.IntegerField() + + @classmethod + def new(cls): + # initialize key-parameter fields with fresh defaults + _new = cls() + _new.encryption_key_parameters = encryption.KeyParameters() + return _new + + @property + def dataclass_type(self) -> type[_DATACLASS]: + raise NotImplementedError( + f"{self.__class__} requires a `dataclass_type` attribute or property" + ) + + def rotate_encryption(self): + with dibs(self): + self.encrypted_json, self.encryption_key_parameters = ( + encryption.pls_rotate_encryption( + encrypted=self.encrypted_json, + stored_params=self.encryption_key_parameters, + ) + ) + self.save() + + @property + def decrypted_dataclass(self) -> _DATACLASS: + return self.dataclass_type(**self.decrypted_json) + + @decrypted_dataclass.setter + def decrypted_dataclass(self, value: _DATACLASS): + if not isinstance(value, self.dataclass_type): + raise ValidationError( + f"expected instance of {self.dataclass_type}, got {value}" + ) + self.decrypted_json = json_for_dataclass(value) + + @property + def decrypted_json(self): + return encryption.pls_decrypt_json( + self.encrypted_json, self.encryption_key_parameters + ) + + @decrypted_json.setter + def decrypted_json(self, value): + self.encrypted_json = encryption.pls_encrypt_json( + value, self.encryption_key_parameters + ) + + @property + def encryption_key_parameters(self) -> encryption.KeyParameters: + return encryption.KeyParameters( + salt=self._salt, + scrypt_block_size=self._scrypt_block_size, + scrypt_cost_log2=self._scrypt_cost_log2, + scrypt_parallelization=self._scrypt_parallelization, + ) + + @encryption_key_parameters.setter + def encryption_key_parameters(self, value: encryption.KeyParameters) -> None: + self._salt = value.salt + self._scrypt_block_size = value.scrypt_block_size + self._scrypt_cost_log2 = value.scrypt_cost_log2 + self._scrypt_parallelization = value.scrypt_parallelization diff --git a/addon_service/credentials/encryption.py b/addon_service/common/encryption.py similarity index 100% rename from addon_service/credentials/encryption.py rename to addon_service/common/encryption.py diff --git a/addon_service/credentials/models.py b/addon_service/credentials/models.py index acbb39fc..00db8f61 100644 --- a/addon_service/credentials/models.py +++ b/addon_service/credentials/models.py @@ -1,21 +1,11 @@ from django.core.exceptions import ValidationError from django.db import models -from addon_service.common.base_model import AddonsServiceBaseModel -from addon_service.common.dibs import dibs +from addon_service.common.encrypted_dataclass_model import EncryptedDataclassModel from addon_toolkit.credentials import Credentials -from addon_toolkit.json_arguments import json_for_dataclass -from . import encryption - - -class ExternalCredentials(AddonsServiceBaseModel): - encrypted_json = models.BinaryField() - _salt = models.BinaryField() - _scrypt_block_size = models.IntegerField() - _scrypt_cost_log2 = models.IntegerField() - _scrypt_parallelization = models.IntegerField() +class ExternalCredentials(EncryptedDataclassModel[Credentials]): # Attributes inherited from back-references: # authorized_storage_account (AuthorizedStorageAccount._credentials, One2One) @@ -27,64 +17,9 @@ class Meta: models.Index(fields=["modified"]), # for schedule_encryption_rotation ) - @classmethod - def new(cls): - # initialize key-parameter fields with fresh defaults - _new = cls() - _new._key_parameters = encryption.KeyParameters() - return _new - - ### - # public encryption-related methods - - @property - def decrypted_credentials(self) -> Credentials: - """Returns a Dataclass instance of the credentials for performing Addon Operations.""" - return self.format.dataclass(**self._decrypted_json) - - @decrypted_credentials.setter - def decrypted_credentials(self, value: Credentials): - self._decrypted_json = json_for_dataclass(value) - - def rotate_encryption(self): - with dibs(self): - self.encrypted_json, self._key_parameters = ( - encryption.pls_rotate_encryption( - encrypted=self.encrypted_json, - stored_params=self._key_parameters, - ) - ) - self.save() - - ### - # private encryption-related methods - - @property - def _decrypted_json(self): - return encryption.pls_decrypt_json(self.encrypted_json, self._key_parameters) - - @_decrypted_json.setter - def _decrypted_json(self, value): - self.encrypted_json = encryption.pls_encrypt_json(value, self._key_parameters) - @property - def _key_parameters(self) -> encryption.KeyParameters: - return encryption.KeyParameters( - salt=self._salt, - scrypt_block_size=self._scrypt_block_size, - scrypt_cost_log2=self._scrypt_cost_log2, - scrypt_parallelization=self._scrypt_parallelization, - ) - - @_key_parameters.setter - def _key_parameters(self, value: encryption.KeyParameters) -> None: - self._salt = value.salt - self._scrypt_block_size = value.scrypt_block_size - self._scrypt_cost_log2 = value.scrypt_cost_log2 - self._scrypt_parallelization = value.scrypt_parallelization - - # END encryption-related methods - ### + def dataclass_type(self) -> type[Credentials]: + return self.format.dataclass @property def authorized_accounts(self): @@ -120,6 +55,6 @@ def _validate_credentials(self): if not self.authorized_accounts: return try: - self.decrypted_credentials + self.decrypted_dataclass except TypeError as e: raise ValidationError(e) diff --git a/addon_service/oauth1/models.py b/addon_service/oauth1/models.py index e11fc0e8..0cdba61c 100644 --- a/addon_service/oauth1/models.py +++ b/addon_service/oauth1/models.py @@ -1,6 +1,10 @@ +import hashlib + from django.db import models from addon_service.common.base_model import AddonsServiceBaseModel +from addon_service.common.encrypted_dataclass_model import EncryptedDataclassModel +from addon_toolkit.credentials import OAuth1Credentials class OAuth1ClientConfig(AddonsServiceBaseModel): @@ -31,3 +35,31 @@ def __repr__(self): return f'<{self.__class__.__qualname__}(pk="{self.pk}", auth_uri="{self.auth_url}, access_token_url="{self.access_token_url}", request_token_url="{self.request_token_url}", client_key="{self.client_key}")>' __str__ = __repr__ + + +class Oauth1TemporaryCredentialsManager(models.Manager): + def filter_by_oauth1_temporary_token(self, oauth1_temporary_token: str): + return self.filter( + oauth1_temporary_token_hash=_hash_temp_token(oauth1_temporary_token) + ) + + +class OAuth1TemporaryCredentials(EncryptedDataclassModel[OAuth1Credentials]): + dataclass_type = OAuth1Credentials + + oauth1_temporary_token_hash = models.CharField() + + class Meta: + verbose_name = "OAuth1 Temporary Credentials" + verbose_name_plural = "OAuth1 Temporary Credentialss" + app_label = "addon_service" + indexes = [ + models.Index(fields="temporary_token_hash"), + ] + + def set_hashed_temporary_token(self, oauth1_temporary_token: str): + self.oauth1_temporary_token_hash = _hash_temp_token(oauth1_temporary_token) + + +def _hash_temp_token(oauth1_temporary_token: str): + return hashlib.sha384(oauth1_temporary_token.encode()) diff --git a/addon_service/oauth1/views.py b/addon_service/oauth1/views.py index f0acd358..6a49af6b 100644 --- a/addon_service/oauth1/views.py +++ b/addon_service/oauth1/views.py @@ -5,26 +5,26 @@ from addon_service.authorized_storage_account.models import AuthorizedStorageAccount from addon_service.common.known_imps import AddonImpNumbers +from addon_service.oauth1.models import OAuth1TemporaryCredentials from addon_service.oauth1.utils import get_access_token from addon_service.oauth_utlis import update_external_account_id -from addon_service.osf_models.fields import decrypt_string def oauth1_callback_view(request): - oauth_token = request.GET["oauth_token"] + temporary_oauth_token = request.GET["oauth_token"] oauth_verifier = request.GET["oauth_verifier"] - pk = decrypt_string(request.session.get("oauth1a_account_id")) - del request.session["oauth1a_account_id"] - - account = AuthorizedStorageAccount.objects.get(pk=pk) - + account = AuthorizedStorageAccount.objects.get( + _temporary_oauth1_credentials__in=OAuth1TemporaryCredentials.objects.filter_by_oauth1_temporary_token( + temporary_oauth_token + ) + ) oauth1_client_config = account.external_service.oauth1_client_config final_credentials, other_info = async_to_sync(get_access_token)( access_token_url=oauth1_client_config.access_token_url, oauth_consumer_key=oauth1_client_config.client_key, oauth_consumer_secret=oauth1_client_config.client_secret, - oauth_token=oauth_token, + oauth_token=temporary_oauth_token, oauth_token_secret=account.temporary_oauth1_credentials.oauth_token_secret, oauth_verifier=oauth_verifier, ) diff --git a/addon_service/tests/_helpers.py b/addon_service/tests/_helpers.py index 1f861235..070f2dd0 100644 --- a/addon_service/tests/_helpers.py +++ b/addon_service/tests/_helpers.py @@ -272,6 +272,6 @@ def _mock_scrypt(secret, salt, n, r, p, dklen, maxmem): return b"\xdd\xd1\xdfN9\n\xbb\xa5\x9a|\xc6\x1f\xd6b\xf2\xfc>\x1e\xfe\xfd\x14\xc6n\xd7\x18\xbf'\x04qk\x8c\xfb" return patch( - "addon_service.credentials.encryption.hashlib.scrypt", + "addon_service.common.encryption.hashlib.scrypt", side_effect=_mock_scrypt, )