Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from common.constants.permission_constants import CompareConstants, PermissionConstants, Permission, Group, Operate, \
ViewPermission, RoleConstants
from common.exception.app_exception import AppAuthenticationFailed
from common.log.log import log
from common.response import result
from common.swagger_api.common_api import CommonApi
from common.util.common import query_params_to_single_dict
Expand Down Expand Up @@ -603,6 +604,7 @@ class Page(APIView):
responses=result.get_page_api_response(ApplicationApi.get_response_body_api()),
tags=[_('Application')])
@has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND)
@log(menu=_('Application'), operate=_("Get the application list by page"))
def get(self, request: Request, current_page: int, page_size: int):
return result.success(
ApplicationSerializer.Query(
Expand Down
85 changes: 85 additions & 0 deletions apps/common/log/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: log.py
@date:2025/3/14 16:09
@desc:
"""

from setting.models.log_management import Log


def _get_ip_address(request):
"""
获取ip地址
@param request:
@return:
"""
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
ip = x_forwarded_for.split(',')[0]
else:
ip = request.META.get('REMOTE_ADDR')
return ip


def _get_user(request):
"""
获取用户
@param request:
@return:
"""
user = request.user
return {
"id": str(user.id),
"email": user.email,
"phone": user.phone,
"nick_name": user.nick_name,
"username": user.username,
"role": user.role,
}


def _get_details(request):
path = request.path
body = request.data
query = request.query_params
return {
'path': path,
'body': body,
'query': query
}


def log(menu: str, operate, get_user=_get_user, get_ip_address=_get_ip_address, get_details=_get_details):
"""
记录审计日志
@param menu: 操作菜单 str
@param operate: 操作 str|func 如果是一个函数 入参将是一个request 响应为str def operate(request): return "操作菜单"
@param get_user: 获取用户
@param get_ip_address:获取IP地址
@param get_details: 获取执行详情
@return:
"""

def inner(func):
def run(view, request, **kwargs):
status = 200
try:
return func(view, request, **kwargs)
except Exception as e:
status = 500
finally:
ip = get_ip_address(request)
user = get_user(request)
details = get_details(request)
_operate = operate
if callable(operate):
_operate = operate(request)
# 插入审计日志
Log(menu=menu, operate=_operate, user=user, status=status, ip_address=ip, details=details).save()

return run

return inner
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are some points of improvement:

  1. Variable Naming: The function names like log, _operate, etc., could be more descriptive to reflect what they do.

  2. Exception Handling: Adding exception handling for database operations can make error management better.

  3. Docstring Clarity: Enhance the docstrings with comments explaining their purpose and usage.

  4. Function Annotations: Explicitly annotate the parameters of the functions to improve readability.

  5. Use Default Arguments Wisely: Ensure default arguments behave as intended without causing unintended side effects.

  6. Code Duplication: Remove duplicate logic, such as repeated calls to _get_ip_address and _get_user.

  7. Logging Levels: Consider adding different logging levels (e.g., debug/info) in the log entry based on the status code.

Here's a revised version incorporating these points:

# coding=utf-8
"""
@project: MaxKB
@author:虎
@file: log.py
@date:2025/3/14 16:09
@desc:
"""

from django.db import transaction
from .models.log_management import Log


def _get_ip_address(request):
    """
    获取 IP 地址
    :param request:
    :return:
    """
    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
    if x_forwarded_for:
        ip = x_forwarded_for.split(',')[0]
    else:
        ip = request.META.get('REMOTE_ADDR', '')
    return ip


def _get_user(request):
    """
    获取用户信息
    :param request:
    :return:
    """
    user = request.user
    return {
        "id": str(user.id),
        "email": user.email,
        "phone": user.phone,
        "nick_name": user.nick_name,
        "username": user.username,
        "role": getattr(user, 'role'),
    }


def _get_details(request):
    """
    获取请求详细信息
    :param request:
    :return:
    """
    path = request.path
    body = request.body.decode() if request.method == 'POST' else None
    query = request.GET.dict()
    return {
        'path': path,
        'body': body,
        'query': query
    }


def record_audit_log(menu: str, operate: Union[str, Callable], get_user: Optional[Callable] = _get_user, 
                    get_ip_address: Optional[Callable] = _get_ip_address, 
                    get_details: Optional[Callable] = _get_details) -> Callable[[Any), Any]]:
    """
    记录审计日志的装饰器

    :param menu: 操作类别(str)
    :param operate: 操作方法(str或callable),若 callable 则为函数,入参为请求对象,返回值为要写入的日志内容
    :param get_user: 用户提取函数,默认使用内置的 `_get_user`
    :param get_ip_address: 请求IP地址提取函数,默认使用内置的 `_get_ip_address`
    :param get_details: 请求详细信息提取函数,默认使用内置的 `_get_details`

    返回:
    打印并记录到数据库中的视图处理函数。
    """

    def decorator(view_func):
        @wraps(view_func)
        def wrapper(request, *args, **kwargs):
            try:
                response = view_func(request, *args, **kwargs)
                status_code = response.status_code
                user_info = get_user(request)
                ip_address = get_ip_address(request)
                details = get_details(request)

                # If operate is callable, call it with the current request
                if callable(operate):
                    _operate = operate(request)
                else:
                    _operate = operate

                # Insert audit log into DB using non-committing transactions
                with transaction.atomic():
                    Log.objects.create(
                        menu=menu,
                        operate=_operate,
                        user=user_info,
                        status=status_code,
                        ip_address=ip_address,
                        details=details
                    )
                
                return response
            except Exception as e:
                # Handle errors while writing to the database
                print(f"Error while recording log: {e}")
                raise
            
        return wrapper
    
    return decorator

This revised version addresses several issues identified, making it clearer, maintainable, and robust.

16 changes: 16 additions & 0 deletions apps/dataset/views/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from common.auth import TokenAuth, has_permissions
from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate, \
ViewPermission, RoleConstants
from common.log.log import log
from common.response import result
from common.response.result import get_page_request_params, get_page_api_response, get_api_response
from common.swagger_api.common_api import CommonApi
Expand All @@ -42,6 +43,7 @@ class SyncWeb(APIView):
dynamic_tag=keywords.get('dataset_id'))],
compare=CompareConstants.AND), PermissionConstants.DATASET_EDIT,
compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("同步Web站点知识库"))
def put(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.SyncWeb(
data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id,
Expand All @@ -60,6 +62,7 @@ class CreateQADataset(APIView):
tags=[_('Knowledge Base')]
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("创建QA知识库"))
def post(self, request: Request):
return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_qa({
'file_list': request.FILES.getlist('file'),
Expand All @@ -79,6 +82,7 @@ class CreateWebDataset(APIView):
tags=[_('Knowledge Base')]
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("Create a web site knowledge base"))
def post(self, request: Request):
return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data))

Expand All @@ -92,6 +96,7 @@ class Application(APIView):
responses=result.get_api_array_response(
DataSetSerializers.Application.get_response_body_api()),
tags=[_('Knowledge Base')])
@log(menu=_('Knowledge Base'), operate=_("Get a list of applications available in the knowledge base"))
def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(
data={'id': dataset_id, 'user_id': str(request.user.id)}).list_application())
Expand All @@ -103,6 +108,7 @@ def get(self, request: Request, dataset_id: str):
responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()),
tags=[_('Knowledge Base')])
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("Get a list of knowledge bases"))
def get(self, request: Request):
data = {key: str(value) for key, value in request.query_params.items()}
d = DataSetSerializers.Query(data={**data, 'user_id': str(request.user.id)})
Expand All @@ -117,6 +123,7 @@ def get(self, request: Request):
tags=[_('Knowledge Base')]
)
@has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("Create a knowledge base"))
def post(self, request: Request):
return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data))

Expand All @@ -130,6 +137,7 @@ class HitTest(APIView):
tags=[_('Knowledge Base')])
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Hit test list"))
def get(self, request: Request, dataset_id: str):
return result.success(
DataSetSerializers.HitTest(data={'id': dataset_id, 'user_id': request.user.id,
Expand All @@ -150,6 +158,7 @@ class Embedding(APIView):
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Re-vectorize"))
def put(self, request: Request, dataset_id: str):
return result.success(
DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).re_embedding())
Expand All @@ -164,6 +173,7 @@ class Export(APIView):
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Export knowledge base"))
def get(self, request: Request, dataset_id: str):
return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_excel()

Expand All @@ -178,6 +188,7 @@ class ExportZip(APIView):
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Export knowledge base containing images"))
def get(self, request: Request, dataset_id: str):
return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_zip()

Expand All @@ -193,6 +204,7 @@ class Operate(APIView):
dynamic_tag=keywords.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE,
dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("Delete knowledge base"))
def delete(self, request: Request, dataset_id: str):
operate = DataSetSerializers.Operate(data={'id': dataset_id})
return result.success(operate.delete())
Expand All @@ -205,6 +217,7 @@ def delete(self, request: Request, dataset_id: str):
tags=[_('Knowledge Base')])
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Query knowledge base details based on knowledge base id"))
def get(self, request: Request, dataset_id: str):
return result.success(DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).one(
user_id=request.user.id))
Expand All @@ -219,6 +232,7 @@ def get(self, request: Request, dataset_id: str):
)
@has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id')))
@log(menu=_('Knowledge Base'), operate=_("Modify knowledge base information"))
def put(self, request: Request, dataset_id: str):
return result.success(
DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).edit(request.data,
Expand All @@ -236,6 +250,7 @@ class Page(APIView):
tags=[_('Knowledge Base')]
)
@has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND)
@log(menu=_('Knowledge Base'), operate=_("Get the knowledge base paginated list"))
def get(self, request: Request, current_page, page_size):
d = DataSetSerializers.Query(
data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None),
Expand All @@ -253,6 +268,7 @@ class Model(APIView):
[lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=keywords.get('dataset_id'))],
compare=CompareConstants.AND))
@log(menu=_('Knowledge Base'), operate=_("Get the model list of the knowledge base"))
def get(self, request: Request, dataset_id: str):
return result.success(
ModelSerializer.Query(
Expand Down
Loading