diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..4c492fd --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,115 @@ +from django.test import TestCase +from django.contrib.auth import get_user_model +from core.auth import ( + MyTokenObtainPairOutSchema, + MyRefreshTokenOutSchema, + MyTokenObtainPairInputSchema, + MyTokenRefreshInputSchema, + CustomAuthBackend, + UserSchema +) +from typing import Dict, Union, Set + +User = get_user_model() + +class TestAuth(TestCase): + def setUp(self): + self.user = User.objects.create_user( + username='testuser', + password='testpass123', + email='test@example.com', + first_name='Test' + ) + self.auth_backend = CustomAuthBackend() + + def test_my_refresh_token_out_schema_dict(self): + # 测试MyRefreshTokenOutSchema的dict方法 + schema = MyRefreshTokenOutSchema(access="test_access", refresh="test_refresh") + result = schema.dict() + expected = { + "data": {"access": "test_access", "refresh": "test_refresh"}, + "message": None, + "success": True + } + self.assertEqual(result, expected) + + # 测试带有可选参数的dict方法 + result = schema.dict( + include={"access"}, # type: ignore + exclude={"refresh"}, # type: ignore + by_alias=True, + skip_defaults=True, + exclude_unset=True, + exclude_defaults=True, + exclude_none=True + ) + self.assertEqual(result, { + "data": {"access": "test_access", "refresh": "test_refresh"}, + "message": None, + "success": True + }) + + def test_my_token_obtain_pair_input_schema_get_token(self): + # 测试MyTokenObtainPairInputSchema的get_token方法 + result = MyTokenObtainPairInputSchema.get_token(self.user) + self.assertIn('data', result) + self.assertIn('refresh', result['data']) + self.assertIn('access', result['data']) + self.assertIn('user', result['data']) + self.assertEqual(result['data']['user'].first_name, 'Test') + self.assertEqual(result['data']['user'].email, 'test@example.com') + + def test_custom_auth_backend_authenticate_success(self): + # 测试CustomAuthBackend的authenticate方法 - 成功情况 + authenticated_user = self.auth_backend.authenticate( + None, + username='testuser', + password='testpass123' + ) + self.assertEqual(authenticated_user, self.user) + + def test_custom_auth_backend_authenticate_wrong_password(self): + # 测试CustomAuthBackend的authenticate方法 - 密码错误 + authenticated_user = self.auth_backend.authenticate( + None, + username='testuser', + password='wrongpass' + ) + self.assertIsNone(authenticated_user) + + def test_custom_auth_backend_authenticate_user_not_exist(self): + # 测试CustomAuthBackend的authenticate方法 - 用户不存在 + authenticated_user = self.auth_backend.authenticate( + None, + username='nonexistentuser', + password='testpass123' + ) + self.assertIsNone(authenticated_user) + + def test_user_schema(self): + # 测试UserSchema + user_schema = UserSchema(first_name="Test", email="test@example.com") + self.assertEqual(user_schema.first_name, "Test") + self.assertEqual(user_schema.email, "test@example.com") + + def test_my_token_obtain_pair_out_schema(self): + # 测试MyTokenObtainPairOutSchema + schema = MyTokenObtainPairOutSchema( + refresh="test_refresh", + access="test_access", + user=UserSchema(first_name="Test", email="test@example.com") + ) + self.assertEqual(schema.refresh, "test_refresh") + self.assertEqual(schema.access, "test_access") + self.assertEqual(schema.user.first_name, "Test") + self.assertEqual(schema.user.email, "test@example.com") + + def test_my_token_obtain_pair_input_schema_get_response_schema(self): + # 测试MyTokenObtainPairInputSchema的get_response_schema方法 + response_schema = MyTokenObtainPairInputSchema.get_response_schema() + self.assertTrue(hasattr(response_schema, '__origin__')) + + def test_my_token_refresh_input_schema_get_response_schema(self): + # 测试MyTokenRefreshInputSchema的get_response_schema方法 + response_schema = MyTokenRefreshInputSchema.get_response_schema() + self.assertEqual(response_schema, MyRefreshTokenOutSchema) \ No newline at end of file diff --git a/tests/test_model_operation.py b/tests/test_model_operation.py new file mode 100644 index 0000000..f699b4b --- /dev/null +++ b/tests/test_model_operation.py @@ -0,0 +1,364 @@ +from django.test import TestCase +from django.core.exceptions import ValidationError +from django.http import Http404 +from django.db import models +from pydantic import BaseModel +from employee.models import Employee, Department +from utils.model_opertion import ( + ModelOperation, + ModelOperationLogger, + ModelOperationHelper, + create, + create_obj_with_validate_unique, + update, + partial_update, + update_by_obj, +) + +class EmployeePayload: + def __init__(self, **kwargs): + self.data = kwargs + + def dict(self): + return self.data + +class TestModel(models.Model): + name = models.CharField(max_length=100, unique=True) + creator = models.CharField(max_length=100) + updater = models.CharField(max_length=100, null=True) + + class Meta: + app_label = 'employee' + +class TestPayload(BaseModel): + name: str + +class SaveErrorModel(models.Model): + name = models.CharField(max_length=100) + creator = models.CharField(max_length=100) + updater = models.CharField(max_length=100, null=True) + + def save(self, *args, **kwargs): + raise ValidationError("Test save error") + + class Meta: + app_label = 'employee' + +class TestModelOperationLogger(TestCase): + def setUp(self): + self.logger = ModelOperationLogger() + + def test_log_create_input(self): + with self.assertLogs('utils.model_opertion', level='INFO') as cm: + self.logger.log_create_input('Employee', {'name': 'John'}) + self.assertIn('input: create=Employee, payload={\'name\': \'John\'}', cm.output[0]) + + def test_log_create_success(self): + with self.assertLogs('utils.model_opertion', level='INFO') as cm: + self.logger.log_create_success('Employee', 1) + self.assertIn('create Employee success, id: 1', cm.output[0]) + + def test_log_create_error(self): + with self.assertLogs('utils.model_opertion', level='ERROR') as cm: + self.logger.log_create_error(Exception('test error')) + self.assertIn('test error', cm.output[0]) + + def test_log_update_input(self): + with self.assertLogs('utils.model_opertion', level='INFO') as cm: + self.logger.log_update_input('Employee', {'name': 'John'}) + self.assertIn('input: update=Employee, payload={\'name\': \'John\'}', cm.output[0]) + + def test_log_update_success(self): + with self.assertLogs('utils.model_opertion', level='INFO') as cm: + self.logger.log_update_success('Employee', 1) + self.assertIn('update Employee success, id: 1', cm.output[0]) + + def test_log_update_error(self): + with self.assertLogs('utils.model_opertion', level='WARNING') as cm: + self.logger.log_update_error(Exception('test error')) + self.assertIn('test error', cm.output[0]) + +class TestModelOperationHelper(TestCase): + def setUp(self): + self.helper = ModelOperationHelper() + self.employee = Employee.objects.create( + first_name='John', + last_name='Doe', + creator='test' + ) + + def test_get_object_by_id(self): + obj = self.helper.get_object_by_id(Employee, self.employee.id) + self.assertEqual(obj.id, self.employee.id) + + def test_get_object_by_id_not_found(self): + with self.assertRaises(Http404): + self.helper.get_object_by_id(Employee, 999) + + def test_create_object(self): + obj = self.helper.create_object( + Employee, + first_name='Jane', + last_name='Doe', + creator='test' + ) + self.assertEqual(obj.first_name, 'Jane') + + def test_validate_unique(self): + obj = Employee( + first_name='John', + last_name='Doe', + creator='test' + ) + with self.assertRaises(ValidationError): + self.helper.validate_unique(obj) + + def test_save_object(self): + obj = Employee( + first_name='Jane', + last_name='Smith', + creator='test' + ) + self.helper.save_object(obj) + self.assertIsNotNone(obj.id) + + def test_save_object_error(self): + obj = SaveErrorModel(name='test', creator='test') + with self.assertRaises(ValidationError): + self.helper.save_object(obj) + + def test_update_object_attrs(self): + self.helper.update_object_attrs( + self.employee, + first_name='Jane', + last_name='Smith' + ) + self.assertEqual(self.employee.first_name, 'Jane') + self.assertEqual(self.employee.last_name, 'Smith') + +class TestModelOperation(TestCase): + def setUp(self): + self.mock_logger = Mock(spec=ModelOperationLogger) + self.mock_helper = Mock(spec=ModelOperationHelper) + self.model_operation = ModelOperation( + logger=self.mock_logger, + helper=self.mock_helper + ) + + def test_create_success(self): + payload = EmployeePayload(first_name='John', last_name='Doe') + mock_obj = Mock(id=1) + self.mock_helper.create_object.return_value = mock_obj + + result = self.model_operation.create('test', Employee, payload) + + self.mock_logger.log_create_input.assert_called_once_with('Employee', payload.dict()) + self.mock_helper.create_object.assert_called_once_with( + Employee, + creator='test', + first_name='John', + last_name='Doe' + ) + self.mock_logger.log_create_success.assert_called_once_with('Employee', 1) + self.assertTrue(result.success) + self.assertEqual(result.data.id, 1) + + def test_create_error(self): + payload = EmployeePayload(first_name='John', last_name='Doe') + error = Exception('test error') + self.mock_helper.create_object.side_effect = error + + result = self.model_operation.create('test', Employee, payload) + + self.mock_logger.log_create_error.assert_called_once_with(error) + self.assertFalse(result.success) + self.assertEqual(result.message, 'test error') + + def test_create_obj_with_validate_unique_success(self): + payload = EmployeePayload(first_name='John', last_name='Doe') + mock_obj = Mock(id=1) + self.mock_helper.save_object.return_value = None + + result = self.model_operation.create_obj_with_validate_unique('test', Employee, payload) + + self.mock_logger.log_create_input.assert_called_once_with('Employee', payload.dict()) + self.mock_helper.validate_unique.assert_called_once() + self.mock_helper.save_object.assert_called_once() + self.mock_logger.log_create_success.assert_called_once() + self.assertTrue(result.success) + + def test_create_obj_with_validate_unique_error(self): + payload = EmployeePayload(first_name='John', last_name='Doe') + error = ValidationError('test error') + self.mock_helper.validate_unique.side_effect = error + + with self.assertRaises(ValidationError): + self.model_operation.create_obj_with_validate_unique('test', Employee, payload) + + self.mock_logger.log_create_error.assert_called_once_with(error) + + def test_create_obj_with_validate_unique_save_error(self): + payload = EmployeePayload(first_name='John', last_name='Doe') + error = ValidationError('test error') + self.mock_helper.save_object.side_effect = error + + with self.assertRaises(ValidationError): + self.model_operation.create_obj_with_validate_unique('test', Employee, payload) + + self.mock_logger.log_create_error.assert_called_once_with(error) + + def test_update_success(self): + payload = EmployeePayload(first_name='Jane') + mock_obj = Mock(id=1, __class__=Mock(__name__='Employee')) + self.mock_helper.get_object_by_id.return_value = mock_obj + + result = self.model_operation.update('test', Employee, payload, 1) + + self.mock_helper.get_object_by_id.assert_called_once_with(Employee, 1) + self.mock_logger.log_update_input.assert_called_once() + self.mock_helper.update_object_attrs.assert_called_once() + self.mock_helper.save_object.assert_called_once_with(mock_obj) + self.mock_logger.log_update_success.assert_called_once() + self.assertTrue(result.success) + + def test_update_error(self): + payload = EmployeePayload(first_name='Jane') + mock_obj = Mock(id=1, __class__=Mock(__name__='Employee')) + self.mock_helper.get_object_by_id.return_value = mock_obj + error = Exception('test error') + self.mock_helper.save_object.side_effect = error + + result = self.model_operation.update('test', Employee, payload, 1) + + self.mock_logger.log_update_error.assert_called_once_with(error) + self.assertFalse(result.success) + self.assertEqual(result.message, 'test error') + + def test_update_get_object_error(self): + payload = EmployeePayload(first_name='Jane') + error = Http404('Not found') + self.mock_helper.get_object_by_id.side_effect = error + + with self.assertRaises(Http404): + self.model_operation.update('test', Employee, payload, 1) + + def test_partial_update(self): + mock_obj = Mock(id=1, __class__=Mock(__name__='Employee')) + self.mock_helper.get_object_by_id.return_value = mock_obj + + result = self.model_operation.partial_update('test', Employee, 1, first_name='Jane') + + self.mock_helper.get_object_by_id.assert_called_once_with(Employee, 1) + self.mock_logger.log_update_input.assert_called_once() + self.mock_helper.update_object_attrs.assert_called_once() + self.mock_helper.save_object.assert_called_once_with(mock_obj) + self.assertTrue(result.success) + + def test_partial_update_get_object_error(self): + error = Http404('Not found') + self.mock_helper.get_object_by_id.side_effect = error + + with self.assertRaises(Http404): + self.model_operation.partial_update('test', Employee, 1, first_name='Jane') + + def test_update_by_obj(self): + mock_obj = Mock(id=1, __class__=Mock(__name__='Employee')) + + result = self.model_operation.update_by_obj(mock_obj, 'test', first_name='Jane') + + self.mock_logger.log_update_input.assert_called_once() + self.mock_helper.update_object_attrs.assert_called_once() + self.mock_helper.save_object.assert_called_once_with(mock_obj) + self.assertTrue(result.success) + + def test_update_by_obj_error(self): + mock_obj = Mock(id=1, __class__=Mock(__name__='Employee')) + error = ValidationError('test error') + self.mock_helper.save_object.side_effect = error + + result = self.model_operation.update_by_obj(mock_obj, 'test', first_name='Jane') + + self.mock_logger.log_update_error.assert_called_once_with(error) + self.assertFalse(result.success) + self.assertEqual(result.message, 'test error') + +class TestModelOperationCompatibility(TestCase): + """测试向后兼容性函数""" + def setUp(self): + self.employee = Employee.objects.create( + first_name='John', + last_name='Doe', + creator='test' + ) + + def test_create(self): + payload = EmployeePayload(first_name='Jane', last_name='Smith') + result = create('test', Employee, payload) + self.assertTrue(result.success) + self.assertIsNotNone(result.data.id) + + def test_create_error(self): + payload = EmployeePayload(first_name='Jane', last_name='Smith', department_id=999) # 不存在的department_id + result = create('test', Employee, payload) + self.assertFalse(result.success) + self.assertIsNone(result.data) + + def test_create_obj_with_validate_unique(self): + payload = EmployeePayload( + first_name='John', # 已存在的名字 + last_name='Doe' + ) + with self.assertRaises(ValidationError): + create_obj_with_validate_unique('test', Employee, payload) + + def test_create_obj_with_validate_unique_save_error(self): + payload = EmployeePayload(name='test') + with self.assertRaises(ValidationError): + create_obj_with_validate_unique('test', SaveErrorModel, payload) + + def test_update(self): + payload = EmployeePayload(first_name='Jane') + result = update('test', Employee, payload, self.employee.id) + self.assertTrue(result.success) + self.employee.refresh_from_db() + self.assertEqual(self.employee.first_name, 'Jane') + + def test_update_not_found(self): + payload = EmployeePayload(first_name='Jane') + with self.assertRaises(Http404): + update('test', Employee, payload, 999) + + def test_update_error(self): + payload = EmployeePayload(name='test') + obj = SaveErrorModel.objects.create(name='test', creator='test') + result = update('test', SaveErrorModel, payload, obj.id) + self.assertFalse(result.success) + self.assertIsNone(result.data) + + def test_partial_update(self): + result = partial_update('test', Employee, self.employee.id, first_name='Jane') + self.assertTrue(result.success) + self.employee.refresh_from_db() + self.assertEqual(self.employee.first_name, 'Jane') + + def test_partial_update_not_found(self): + with self.assertRaises(Http404): + partial_update('test', Employee, 999, first_name='Jane') + + def test_partial_update_error(self): + obj = SaveErrorModel.objects.create(name='test', creator='test') + result = partial_update('test', SaveErrorModel, obj.id, name='test2') + self.assertFalse(result.success) + self.assertIsNone(result.data) + + def test_update_by_obj(self): + result = update_by_obj(self.employee, 'test', first_name='Jane') + self.assertTrue(result.success) + self.employee.refresh_from_db() + self.assertEqual(self.employee.first_name, 'Jane') + + def test_update_by_obj_error(self): + obj = SaveErrorModel.objects.create(name='test', creator='test') + result = update_by_obj(obj, 'test', name='test2') + self.assertFalse(result.success) + self.assertIsNone(result.data) \ No newline at end of file diff --git a/utils/model_opertion.py b/utils/model_opertion.py index f4cb6ac..fa1494f 100644 --- a/utils/model_opertion.py +++ b/utils/model_opertion.py @@ -1,5 +1,5 @@ import logging -from typing import TypeVar, Union, Any, Optional +from typing import TypeVar, Union, Any, Optional, Dict, Type import traceback from django.core.exceptions import ValidationError @@ -13,66 +13,127 @@ GenericPayload = TypeVar("GenericPayload") - -def create(creator: str, model: CoreModel, payload: GenericPayload) -> StandResponse[Optional[DictId]]: - """创建对象""" - try: - logger.info(f"input: create={model.__name__}, payload={payload.dict()}") - obj = model.objects.create(creator=creator, **payload.dict()) - except Exception as e: +class ModelOperationLogger: + @staticmethod + def log_create_input(model_name: str, payload: Dict) -> None: + logger.info(f"input: create={model_name}, payload={payload}") + + @staticmethod + def log_create_success(model_name: str, obj_id: int) -> None: + logger.info(f"create {model_name} success, id: {obj_id}") + + @staticmethod + def log_create_error(error: Exception) -> None: logger.error(traceback.format_exc()) - return StandResponse[Optional[DictId]](success=False, message=str(e), data=None) - logger.info(f"create {model.__name__} success, id: {obj.id}") - return StandResponse[Optional[DictId]](data=DictId(id=obj.id)) - - -def create_obj_with_validate_unique( - creator: str, - model: CoreModel, - payload: GenericPayload, - exclude: Any = None -) -> StandResponse[Union[DictId, dict]]: - """创建对象,并且验证唯一性""" - obj = model( - creator=creator, - **payload.dict() - ) - try: - logger.info(f"input: create={model.__name__}, payload={payload.dict()}") - obj.validate_unique(exclude=exclude) - obj.save() - except ValidationError as e: + + @staticmethod + def log_update_input(model_name: str, payload: Dict) -> None: + logger.info(f"input: update={model_name}, payload={payload}") + + @staticmethod + def log_update_success(model_name: str, obj_id: int) -> None: + logger.info(f"update {model_name} success, id: {obj_id}") + + @staticmethod + def log_update_error(error: Exception) -> None: logger.warning(traceback.format_exc()) - raise e - - logger.info(f"create {model.__name__} success, id: {obj.id}") - return StandResponse[Union[DictId, dict]](data=DictId(id=obj.id)) - -def _update(obj, payload: dict, updater: str) -> OptionalDictResponseType: - logger.info(f"input: update={obj.__class__.__name__}, payload={payload}") - obj.updater = updater - for attr, value in payload.items(): - setattr(obj, attr, value) - try: +class ModelOperationHelper: + @staticmethod + def get_object_by_id(model: Type[CoreModel], obj_id: int) -> CoreModel: + return get_object_or_404(model, id=obj_id) + + @staticmethod + def create_object(model: Type[CoreModel], **kwargs) -> CoreModel: + return model.objects.create(**kwargs) + + @staticmethod + def validate_unique(obj: CoreModel, exclude: Any = None) -> None: + obj.validate_unique(exclude=exclude) + + @staticmethod + def save_object(obj: CoreModel) -> None: obj.save() - except Exception as e: - logger.warning(traceback.format_exc()) - return OptionalDictResponseType(success=False, message=str(e), data=None) - logger.info(f"update {obj.__class__.__name__} success, id: {obj.id}") - return OptionalDictResponseType(data=DictId(id=obj.id)) - - -def update(updater: str, model: CoreModel, payload: GenericPayload, obj_id: conint(ge=1)) -> StandResponse[Union[DictId, None]]: - """更新对象""" - obj = get_object_or_404(model, id=obj_id) - return _update(obj=obj, payload=payload.dict(), updater=updater) - - -def partial_update(updater: str, model: CoreModel, obj_id: conint(ge=1), **kwargs) -> OptionalDictResponseType: - obj = get_object_or_404(model, id=obj_id) - return _update(obj=obj, payload=kwargs, updater=updater) - - -def update_by_obj(obj: CoreModel, updater: str, **kwargs) -> OptionalDictResponseType: - return _update(obj=obj, payload=kwargs, updater=updater) + + @staticmethod + def update_object_attrs(obj: CoreModel, **attrs) -> None: + for attr, value in attrs.items(): + setattr(obj, attr, value) + +class ModelOperation: + def __init__(self, logger: ModelOperationLogger = None, helper: ModelOperationHelper = None): + self.logger = logger or ModelOperationLogger() + self.helper = helper or ModelOperationHelper() + + def create(self, creator: str, model: Type[CoreModel], payload: GenericPayload) -> StandResponse[Optional[DictId]]: + """创建对象""" + try: + payload_dict = payload.dict() + self.logger.log_create_input(model.__name__, payload_dict) + + obj = self.helper.create_object(model, creator=creator, **payload_dict) + + self.logger.log_create_success(model.__name__, obj.id) + return StandResponse[Optional[DictId]](data=DictId(id=obj.id)) + except Exception as e: + self.logger.log_create_error(e) + return StandResponse[Optional[DictId]](success=False, message=str(e), data=None) + + def create_obj_with_validate_unique( + self, + creator: str, + model: Type[CoreModel], + payload: GenericPayload, + exclude: Any = None + ) -> StandResponse[Union[DictId, dict]]: + """创建对象,并且验证唯一性""" + payload_dict = payload.dict() + self.logger.log_create_input(model.__name__, payload_dict) + + obj = model(creator=creator, **payload_dict) + + try: + self.helper.validate_unique(obj, exclude=exclude) + self.helper.save_object(obj) + + self.logger.log_create_success(model.__name__, obj.id) + return StandResponse[Union[DictId, dict]](data=DictId(id=obj.id)) + except ValidationError as e: + self.logger.log_create_error(e) + raise e + + def _update(self, obj: CoreModel, payload: dict, updater: str) -> OptionalDictResponseType: + self.logger.log_update_input(obj.__class__.__name__, payload) + + try: + obj.updater = updater + self.helper.update_object_attrs(obj, **payload) + self.helper.save_object(obj) + + self.logger.log_update_success(obj.__class__.__name__, obj.id) + return OptionalDictResponseType(data=DictId(id=obj.id)) + except Exception as e: + self.logger.log_update_error(e) + return OptionalDictResponseType(success=False, message=str(e), data=None) + + def update(self, updater: str, model: Type[CoreModel], payload: GenericPayload, obj_id: conint(ge=1)) -> StandResponse[Union[DictId, None]]: + """更新对象""" + obj = self.helper.get_object_by_id(model, obj_id) + return self._update(obj=obj, payload=payload.dict(), updater=updater) + + def partial_update(self, updater: str, model: Type[CoreModel], obj_id: conint(ge=1), **kwargs) -> OptionalDictResponseType: + obj = self.helper.get_object_by_id(model, obj_id) + return self._update(obj=obj, payload=kwargs, updater=updater) + + def update_by_obj(self, obj: CoreModel, updater: str, **kwargs) -> OptionalDictResponseType: + return self._update(obj=obj, payload=kwargs, updater=updater) + +# 创建全局实例以保持向后兼容性 +model_operation = ModelOperation() + +# 导出函数以保持向后兼容性 +create = model_operation.create +create_obj_with_validate_unique = model_operation.create_obj_with_validate_unique +update = model_operation.update +partial_update = model_operation.partial_update +update_by_obj = model_operation.update_by_obj