1313# limitations under the License.
1414
1515import re
16- from typing import Dict , List , Optional , Tuple , Union
16+ from typing import Any , Dict , List , Optional , Tuple , Union
1717from uuid import UUID
1818
1919from fastapi import APIRouter , Depends , HTTPException , Query , Security , status
2929from argilla_server .errors .future .base_errors import MISSING_VECTOR_ERROR_CODE
3030from argilla_server .models import Dataset as DatasetModel
3131from argilla_server .models import Record , User
32- from argilla_server .policies import DatasetPolicyV1 , authorize
32+ from argilla_server .policies import DatasetPolicyV1 , RecordPolicyV1 , authorize , is_authorized
3333from argilla_server .schemas .v1 .datasets import Dataset
3434from argilla_server .schemas .v1 .records import (
3535 Filters ,
8383_VALID_SORT_VALUES = tuple (sort .value for sort in SortOrder )
8484_METADATA_PROPERTY_SORT_BY_REGEX = re .compile (r"^metadata\.(?P<name>(?=.*[a-z0-9])[a-z0-9_-]+)$" )
8585
86-
8786SortByQueryParamParsed = Annotated [
8887 Dict [str , str ],
8988 Depends (
@@ -410,7 +409,7 @@ async def list_current_user_dataset_records(
410409 limit : int = Query (default = LIST_DATASET_RECORDS_LIMIT_DEFAULT , ge = 1 , le = LIST_DATASET_RECORDS_LIMIT_LE ),
411410 current_user : User = Security (auth .get_current_user ),
412411):
413- dataset = await _get_dataset_or_raise (db , dataset_id )
412+ dataset = await _get_dataset_or_raise (db , dataset_id , with_metadata_properties = True )
414413
415414 await authorize (current_user , DatasetPolicyV1 .get (dataset ))
416415
@@ -427,6 +426,10 @@ async def list_current_user_dataset_records(
427426 sort_by_query_param = sort_by_query_param ,
428427 )
429428
429+ for record in records :
430+ record .dataset = dataset
431+ record .metadata_ = await _filter_record_metadata_for_user (record , current_user )
432+
430433 return Records (items = records , total = total )
431434
432435
@@ -570,8 +573,7 @@ async def search_current_user_dataset_records(
570573 limit : int = Query (default = LIST_DATASET_RECORDS_LIMIT_DEFAULT , ge = 1 , le = LIST_DATASET_RECORDS_LIMIT_LE ),
571574 current_user : User = Security (auth .get_current_user ),
572575):
573- dataset = await _get_dataset_or_raise (db , dataset_id , with_fields = True )
574-
576+ dataset = await _get_dataset_or_raise (db , dataset_id , with_fields = True , with_metadata_properties = True )
575577 await authorize (current_user , DatasetPolicyV1 .search_records (dataset ))
576578
577579 await _validate_search_records_query (db , body , dataset_id )
@@ -589,7 +591,7 @@ async def search_current_user_dataset_records(
589591 sort_by_query_param = sort_by_query_param ,
590592 )
591593
592- record_id_score_map = {
594+ record_id_score_map : Dict [ UUID , Dict [ str , Union [ float , SearchRecord , None ]]] = {
593595 response .record_id : {"query_score" : response .score , "search_record" : None }
594596 for response in search_responses .items
595597 }
@@ -603,6 +605,9 @@ async def search_current_user_dataset_records(
603605 )
604606
605607 for record in records :
608+ record .dataset = dataset
609+ record .metadata_ = await _filter_record_metadata_for_user (record , current_user )
610+
606611 record_id_score_map [record .id ]["search_record" ] = SearchRecord (
607612 record = RecordSchema .from_orm (record ), query_score = record_id_score_map [record .id ]["query_score" ]
608613 )
@@ -698,3 +703,14 @@ async def list_dataset_records_search_suggestions_options(
698703 for sa in suggestion_agents_by_question
699704 ]
700705 )
706+
707+
708+ async def _filter_record_metadata_for_user (record : Record , user : User ) -> Optional [Dict [str , Any ]]:
709+ if record .metadata_ is None :
710+ return None
711+
712+ metadata = {}
713+ for metadata_name in list (record .metadata_ .keys ()):
714+ if await is_authorized (user , RecordPolicyV1 .get_metadata (record , metadata_name )):
715+ metadata [metadata_name ] = record .metadata_ [metadata_name ]
716+ return metadata
0 commit comments