diff --git a/synapseclient/api/__init__.py b/synapseclient/api/__init__.py index 702fc3f85..3174d0c08 100644 --- a/synapseclient/api/__init__.py +++ b/synapseclient/api/__init__.py @@ -20,6 +20,7 @@ get_entity_id_version_bundle2, post_entity_bundle2_create, put_entity_id_bundle2, + store_entity_with_bundle2, ) from .entity_factory import get_from_entity_factory from .entity_services import ( @@ -64,6 +65,7 @@ "get_entity_id_version_bundle2", "post_entity_bundle2_create", "put_entity_id_bundle2", + "store_entity_with_bundle2", # file_services "post_file_multipart", "put_file_multipart_add", diff --git a/synapseclient/api/entity_bundle_services_v2.py b/synapseclient/api/entity_bundle_services_v2.py index d842b92ed..3b9702600 100644 --- a/synapseclient/api/entity_bundle_services_v2.py +++ b/synapseclient/api/entity_bundle_services_v2.py @@ -152,3 +152,92 @@ async def put_entity_id_bundle2( + (f"?generatedBy={generated_by}" if generated_by else ""), body=json.dumps(request), ) + + +async def store_entity_with_bundle2( + entity: Dict[str, Any], + parent_id: Optional[str] = None, + acl: Optional[Dict[str, Any]] = None, # TODO: Consider skipping ACL? + annotations: Optional[Dict[str, Any]] = None, + activity: Optional[Dict[str, Any]] = None, + new_version: bool = False, + force_version: bool = False, + *, + synapse_client: Optional["Synapse"] = None, +) -> Dict[str, Any]: + """ + Store an entity in Synapse using the bundle2 API endpoints to reduce HTTP calls. + + This function follows a specific flow: + 1. Determines if the operation is a create or update: + - If no ID is provided, searches for the ID via /entity/child + - If no ID is found, treats as a Create + - If an ID is found, treats as an Update + + 2. For Updates: + - Retrieves entity by ID and merges with existing data + - Updates desired fields in the retrieved object + - Pushes modified object with HTTP PUT if there are changes + + 3. For Creates: + - Creates a new object with desired fields + - Pushes the new object with HTTP POST + + Arguments: + entity: The entity to store. + parent_id: The ID of the parent entity for creation. + acl: Access control list for the entity. + annotations: Annotations to associate with the entity. + activity: Activity to associate with the entity. + new_version: If True, create a new version of the entity. + force_version: If True, forces a new version of an entity even if nothing has changed. + synapse_client: Synapse client instance. + + Returns: + The stored entity bundle. + """ + from synapseclient import Synapse + + client = Synapse.get_client(synapse_client=synapse_client) + + # Determine if this is a create or update operation + entity_id = entity.get("id", None) + + # Construct bundle request based on provided data + bundle_request = {"entity": entity} + + if annotations: + bundle_request["annotations"] = annotations + + if acl: + bundle_request["accessControlList"] = acl + + if activity: + bundle_request["activity"] = activity + + # Handle create or update + if not entity_id: + # This is a creation + client.logger.debug("Creating new entity via bundle2 API") + + # For creation, parent ID is required + # TODO: Projects won't have a parent in this case + # if parent_id: + # # Add parentId to the entity if not already set + # if not entity.get("parentId"): + # entity["parentId"] = parent_id + # elif not entity.get("parentId"): + # raise ValueError("Parent ID must be provided for entity creation") + + # Create entity using bundle2 create endpoint + return await post_entity_bundle2_create( + request=bundle_request, + generated_by=activity.get("id") if activity else None, + synapse_client=synapse_client, + ) + else: + # This is an update + client.logger.debug(f"Updating entity {entity_id} via bundle2 API") + + # For updates we might need to retrieve the existing entity to merge data + # Only retrieve if we need diff --git a/synapseclient/api/entity_factory.py b/synapseclient/api/entity_factory.py index 020ff2710..9dcf83335 100644 --- a/synapseclient/api/entity_factory.py +++ b/synapseclient/api/entity_factory.py @@ -245,7 +245,7 @@ async def _handle_file_entity( from synapseclient.models import FileHandle entity_instance.fill_from_dict( - synapse_file=entity_bundle["entity"], set_annotations=False + synapse_file=entity_bundle["entity"], annotations=None ) # Update entity with FileHandle metadata @@ -401,7 +401,7 @@ class type. This will also download the file if `download_file` is set to True. ) else: # Handle all other entity types - entity_instance.fill_from_dict(entity_bundle["entity"], set_annotations=False) + entity_instance.fill_from_dict(entity_bundle["entity"], annotations=None) if annotations: entity_instance.annotations = annotations diff --git a/synapseclient/models/annotations.py b/synapseclient/models/annotations.py index d6c1d609c..c8d2fa958 100644 --- a/synapseclient/models/annotations.py +++ b/synapseclient/models/annotations.py @@ -1,13 +1,16 @@ """The required data for working with annotations in Synapse""" -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from datetime import date, datetime from typing import Dict, List, Optional, Union +from typing_extensions import Any + from synapseclient import Synapse -from synapseclient.annotations import ANNO_TYPE_TO_FUNC +from synapseclient.annotations import ANNO_TYPE_TO_FUNC, _convert_to_annotations_list from synapseclient.api import set_annotations_async from synapseclient.core.async_utils import async_to_sync +from synapseclient.core.utils import delete_none_keys from synapseclient.models.protocols.annotations_protocol import ( AnnotationsSynchronousProtocol, ) @@ -136,3 +139,23 @@ def from_dict( annotations[key] = dict_to_convert[key] return annotations + + def to_synapse_request(self) -> Dict[str, Any]: + """Convert the annotations to the format the synapse rest API works in. + + Returns: + The annotations in the format the synapse rest API works in. + """ + annotations_dict = asdict(self) + + synapse_annotations = _convert_to_annotations_list( + annotations_dict["annotations"] or {} + ) + + result = { + "annotations": synapse_annotations, + "id": self.id, + "etag": self.etag, + } + delete_none_keys(result) + return result diff --git a/synapseclient/models/dataset.py b/synapseclient/models/dataset.py index 9b145d205..21f7bea98 100644 --- a/synapseclient/models/dataset.py +++ b/synapseclient/models/dataset.py @@ -855,12 +855,12 @@ def _set_last_persistent_instance(self) -> None: [dataclasses.replace(item) for item in self.items] if self.items else [] ) - def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": + def fill_from_dict(self, entity, annotations: Dict = None) -> "Self": """ Converts the data coming from the Synapse API into this datamodel. Arguments: - synapse_table: The data coming from the Synapse API + entity: The data coming from the Synapse API Returns: The Dataset object instance. @@ -887,8 +887,8 @@ def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": for item in entity.get("items", []) ] - if set_annotations: - self.annotations = Annotations.from_dict(entity.get("annotations", {})) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self def to_synapse_request(self): @@ -2255,12 +2255,12 @@ def _set_last_persistent_instance(self) -> None: [dataclasses.replace(item) for item in self.items] if self.items else [] ) - def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": + def fill_from_dict(self, entity, annotations: Dict = None) -> "Self": """ Converts the data coming from the Synapse API into this datamodel. Arguments: - synapse_table: The data coming from the Synapse API + entity: The data coming from the Synapse API Returns: The DatasetCollection object instance. @@ -2283,8 +2283,8 @@ def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": EntityRef(id=item["entityId"], version=item["versionNumber"]) for item in entity.get("items", []) ] - if set_annotations: - self.annotations = Annotations.from_dict(entity.get("annotations", {})) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self def to_synapse_request(self): diff --git a/synapseclient/models/entityview.py b/synapseclient/models/entityview.py index 640a0eeeb..95f18859d 100644 --- a/synapseclient/models/entityview.py +++ b/synapseclient/models/entityview.py @@ -739,9 +739,7 @@ def _set_last_persistent_instance(self) -> None: deepcopy(self.scope_ids) if self.scope_ids else set() ) - def fill_from_dict( - self, entity: Dict, set_annotations: bool = True - ) -> "EntityView": + def fill_from_dict(self, entity: Dict, annotations: Dict = None) -> "EntityView": """ Converts the data coming from the Synapse API into this datamodel. @@ -768,8 +766,8 @@ def fill_from_dict( self.view_type_mask = entity.get("viewTypeMask", None) self.scope_ids = set(f"syn{id}" for id in entity.get("scopeIds", [])) - if set_annotations: - self.annotations = Annotations.from_dict(entity.get("annotations", {})) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self def to_synapse_request(self): diff --git a/synapseclient/models/file.py b/synapseclient/models/file.py index 3a26c9f8b..d12c04033 100644 --- a/synapseclient/models/file.py +++ b/synapseclient/models/file.py @@ -6,7 +6,7 @@ from copy import deepcopy from dataclasses import dataclass, field from datetime import date, datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from synapseclient import File as SynapseFile from synapseclient import Synapse @@ -607,7 +607,7 @@ def _fill_from_file_handle(self) -> None: def fill_from_dict( self, synapse_file: Union[Synapse_File, Dict[str, Union[bool, str, int]]], - set_annotations: bool = True, + annotations: Dict = None, ) -> "File": """ Converts a response from the REST API into this dataclass. @@ -642,10 +642,8 @@ def fill_from_dict( ) self._fill_from_file_handle() - if set_annotations: - self.annotations = Annotations.from_dict( - synapse_file.get("annotations", {}) - ) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self def _cannot_store(self) -> bool: @@ -866,10 +864,16 @@ async def store_async( delete_none_keys(synapse_file) entity = await store_entity( - resource=self, entity=synapse_file, synapse_client=client + entity=self.to_synapse_request(), + parent_id=self.parent_id, + annotations=Annotations(self.annotations).to_synapse_request() + if self.annotations + else None, + synapse_client=synapse_client, + ) + self.fill_from_dict( + entity=entity["entity"], annotations=entity.get("annotations", None) ) - - self.fill_from_dict(synapse_file=entity, set_annotations=False) re_read_required = await store_entity_components( root_resource=self, synapse_client=client @@ -948,7 +952,9 @@ async def change_metadata_async( ), ) - self.fill_from_dict(synapse_file=entity, set_annotations=True) + self.fill_from_dict( + synapse_file=entity, annotations=entity.get("annotations", {}) + ) self._set_last_persistent_instance() Synapse.get_client(synapse_client=synapse_client).logger.debug( f"Change metadata for file {self.name}, id: {self.id}: {self.path}" @@ -1391,3 +1397,32 @@ def _convert_into_legacy_file(self) -> SynapseFile: ) delete_none_keys(return_data) return return_data + + def to_synapse_request(self) -> Dict[str, Any]: + """ + Converts this dataclass into a request that can be sent to the Synapse API. + + Returns: + A dictionary that can be used in a request to the Synapse API. + """ + from synapseclient.core.constants import concrete_types + + entity = { + "name": self.name, + "description": self.description, + "id": self.id, + "etag": self.etag, + "createdOn": self.created_on, + "modifiedOn": self.modified_on, + "createdBy": self.created_by, + "modifiedBy": self.modified_by, + "parentId": self.parent_id, + "dataFileHandleId": self.data_file_handle_id, + "versionLabel": self.version_label, + "versionComment": self.version_comment, + "versionNumber": self.version_number, + "concreteType": concrete_types.FILE_ENTITY, + } + delete_none_keys(entity) + + return entity diff --git a/synapseclient/models/folder.py b/synapseclient/models/folder.py index c141209e3..72f61b7cd 100644 --- a/synapseclient/models/folder.py +++ b/synapseclient/models/folder.py @@ -2,13 +2,14 @@ from copy import deepcopy from dataclasses import dataclass, field, replace from datetime import date, datetime -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from opentelemetry import trace from synapseclient import Synapse from synapseclient.api import get_from_entity_factory from synapseclient.core.async_utils import async_to_sync, otel_trace_method +from synapseclient.core.constants import concrete_types from synapseclient.core.exceptions import SynapseError from synapseclient.core.utils import delete_none_keys, merge_dataclass_entities from synapseclient.entity import Folder as Synapse_Folder @@ -16,9 +17,10 @@ from synapseclient.models.mixins import AccessControllable, StorableContainer from synapseclient.models.protocols.folder_protocol import FolderSynchronousProtocol from synapseclient.models.services.search import get_id +from synapseclient.models.services.storable_entity import store_entity from synapseclient.models.services.storable_entity_components import ( FailureStrategy, - store_entity_components, + store_entity_components_file_folder_only, ) from synapseutils import copy @@ -173,33 +175,54 @@ def _set_last_persistent_instance(self) -> None: ) def fill_from_dict( - self, synapse_folder: Synapse_Folder, set_annotations: bool = True + self, entity: Synapse_Folder, annotations: Dict = None ) -> "Folder": """ Converts a response from the REST API into this dataclass. Arguments: - synapse_file: The response from the REST API. - set_annotations: Whether to set the annotations from the response. + entity: The response from the REST API. + annotations: Optional dictionary containing annotations data. Returns: The Folder object. """ - self.id = synapse_folder.get("id", None) - self.name = synapse_folder.get("name", None) - self.parent_id = synapse_folder.get("parentId", None) - self.description = synapse_folder.get("description", None) - self.etag = synapse_folder.get("etag", None) - self.created_on = synapse_folder.get("createdOn", None) - self.modified_on = synapse_folder.get("modifiedOn", None) - self.created_by = synapse_folder.get("createdBy", None) - self.modified_by = synapse_folder.get("modifiedBy", None) - if set_annotations: - self.annotations = Annotations.from_dict( - synapse_folder.get("annotations", None) - ) + self.id = entity.get("id", None) + self.name = entity.get("name", None) + self.parent_id = entity.get("parentId", None) + self.description = entity.get("description", None) + self.etag = entity.get("etag", None) + self.created_on = entity.get("createdOn", None) + self.modified_on = entity.get("modifiedOn", None) + self.created_by = entity.get("createdBy", None) + self.modified_by = entity.get("modifiedBy", None) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self + def to_synapse_request(self) -> Dict[str, Any]: + """ + Converts this dataclass into a request that can be sent to the Synapse API. + + Returns: + A dictionary that can be used in a request to the Synapse API. + """ + entity = { + "name": self.name, + "description": self.description, + "id": self.id, + "etag": self.etag, + "createdOn": self.created_on, + "modifiedOn": self.modified_on, + "createdBy": self.created_by, + "modifiedBy": self.modified_by, + "parentId": self.parent_id, + "concreteType": concrete_types.FOLDER_ENTITY, + } + delete_none_keys(entity) + + return entity + @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"Folder_Store: {self.name}" ) @@ -265,28 +288,19 @@ async def store_async( } ) if self.has_changed: - loop = asyncio.get_event_loop() - synapse_folder = Synapse_Folder( - id=self.id, - name=self.name, - parent=parent_id, - etag=self.etag, - description=self.description, + entity = await store_entity( + entity=self.to_synapse_request(), + parent_id=self.parent_id, + annotations=Annotations(self.annotations).to_synapse_request() + if self.annotations + else None, + synapse_client=synapse_client, ) - delete_none_keys(synapse_folder) - entity = await loop.run_in_executor( - None, - lambda: Synapse.get_client(synapse_client=synapse_client).store( - obj=synapse_folder, - set_annotations=False, - isRestricted=self.is_restricted, - createOrUpdate=False, - ), + self.fill_from_dict( + entity=entity["entity"], annotations=entity.get("annotations", None) ) - self.fill_from_dict(synapse_folder=entity, set_annotations=False) - - await store_entity_components( + await store_entity_components_file_folder_only( root_resource=self, failure_strategy=failure_strategy, synapse_client=synapse_client, diff --git a/synapseclient/models/materializedview.py b/synapseclient/models/materializedview.py index 862b37f38..fd4b62668 100644 --- a/synapseclient/models/materializedview.py +++ b/synapseclient/models/materializedview.py @@ -667,7 +667,7 @@ def _set_last_persistent_instance(self) -> None: ) def fill_from_dict( - self, entity: Dict, set_annotations: bool = True + self, entity: Dict, annotations: Dict = None ) -> "MaterializedView": """ Converts the data coming from the Synapse API into this datamodel. @@ -694,8 +694,8 @@ def fill_from_dict( self.is_search_enabled = entity.get("isSearchEnabled", False) self.defining_sql = entity.get("definingSQL", None) - if set_annotations: - self.annotations = entity.get("annotations", {}) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self diff --git a/synapseclient/models/mixins/table_components.py b/synapseclient/models/mixins/table_components.py index dd132fb58..c8d86fec1 100644 --- a/synapseclient/models/mixins/table_components.py +++ b/synapseclient/models/mixins/table_components.py @@ -426,15 +426,15 @@ async def store_async( if dry_run: client.logger.info( - f"[{self.id}:{self.name}]: Dry run enabled. No changes will be made." + f"[{self.id or ''}:{self.name}]: Dry run enabled. No changes will be made." ) if self.has_changed: - if self.id: - if dry_run: - client.logger.info( - f"[{self.id}:{self.name}]: Dry run {self.__class__} update, expected changes:" - ) + if dry_run: + client.logger.info( + f"[{self.id or ''}:{self.name}]: Dry run {self.__class__} {('update' if self.id else 'create')}, expected changes:" + ) + if self.id: log_dataclass_diff( logger=client.logger, prefix=f"[{self.id}:{self.name}]: ", @@ -443,17 +443,6 @@ async def store_async( fields_to_ignore=["columns", "_last_persistent_instance"], ) else: - entity = await put_entity_id_bundle2( - entity_id=self.id, - request=self.to_synapse_request(), - synapse_client=synapse_client, - ) - self.fill_from_dict(entity=entity["entity"], set_annotations=False) - else: - if dry_run: - client.logger.info( - f"[{self.id}:{self.name}]: Dry run {self.__class__} update, expected changes:" - ) log_dataclass_diff( logger=client.logger, prefix=f"[{self.name}]: ", @@ -461,11 +450,24 @@ async def store_async( obj2=self, fields_to_ignore=["columns", "_last_persistent_instance"], ) - else: - entity = await post_entity_bundle2_create( - request=self.to_synapse_request(), synapse_client=synapse_client - ) - self.fill_from_dict(entity=entity["entity"], set_annotations=False) + else: + # Use store_entity_with_bundle2 for both creation and update operations + from synapseclient.api.entity_bundle_services_v2 import ( + store_entity_with_bundle2, + ) + from synapseclient.models import Annotations + + entity = await store_entity_with_bundle2( + entity=self.to_synapse_request(), + parent_id=self.parent_id if not self.id else None, + annotations=Annotations(self.annotations).to_synapse_request() + if self.annotations + else None, + synapse_client=synapse_client, + ) + self.fill_from_dict( + entity=entity["entity"], annotations=entity.get("annotations", None) + ) schema_change_request = await self._generate_schema_change_request( dry_run=dry_run, synapse_client=synapse_client diff --git a/synapseclient/models/project.py b/synapseclient/models/project.py index 88c0d4db9..d114a3697 100644 --- a/synapseclient/models/project.py +++ b/synapseclient/models/project.py @@ -5,20 +5,21 @@ from typing import Dict, List, Optional, Union from opentelemetry import trace +from typing_extensions import Any from synapseclient import Synapse -from synapseclient.api import get_from_entity_factory +from synapseclient.api import get_from_entity_factory, store_entity_with_bundle2 from synapseclient.core.async_utils import async_to_sync, otel_trace_method +from synapseclient.core.constants import concrete_types from synapseclient.core.exceptions import SynapseError from synapseclient.core.utils import delete_none_keys, merge_dataclass_entities -from synapseclient.entity import Project as Synapse_Project from synapseclient.models import Annotations, File, Folder from synapseclient.models.mixins import AccessControllable, StorableContainer from synapseclient.models.protocols.project_protocol import ProjectSynchronousProtocol from synapseclient.models.services.search import get_id from synapseclient.models.services.storable_entity_components import ( FailureStrategy, - store_entity_components, + store_entity_components_file_folder_only, ) from synapseutils.copy_functions import copy @@ -201,34 +202,56 @@ def _set_last_persistent_instance(self) -> None: def fill_from_dict( self, - synapse_project: Union[Synapse_Project, Dict], - set_annotations: bool = True, + entity: Dict, + annotations: Optional[Dict] = None, ) -> "Project": """ Converts a response from the REST API into this dataclass. Arguments: - synapse_project: The response from the REST API. + entity: The response from the REST API. Returns: The Project object. """ - self.id = synapse_project.get("id", None) - self.name = synapse_project.get("name", None) - self.description = synapse_project.get("description", None) - self.etag = synapse_project.get("etag", None) - self.created_on = synapse_project.get("createdOn", None) - self.modified_on = synapse_project.get("modifiedOn", None) - self.created_by = synapse_project.get("createdBy", None) - self.modified_by = synapse_project.get("modifiedBy", None) - self.alias = synapse_project.get("alias", None) - self.parent_id = synapse_project.get("parentId", None) - if set_annotations: - self.annotations = Annotations.from_dict( - synapse_project.get("annotations", {}) - ) + self.id = entity.get("id", None) + self.name = entity.get("name", None) + self.description = entity.get("description", None) + self.etag = entity.get("etag", None) + self.created_on = entity.get("createdOn", None) + self.modified_on = entity.get("modifiedOn", None) + self.created_by = entity.get("createdBy", None) + self.modified_by = entity.get("modifiedBy", None) + self.alias = entity.get("alias", None) + self.parent_id = entity.get("parentId", None) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self + def to_synapse_request(self) -> Dict[str, Any]: + """ + Converts this dataclass into a request that can be sent to the Synapse API. + + Returns: + A dictionary that can be used in a request to the Synapse API. + """ + entity = { + "name": self.name, + "description": self.description, + "id": self.id, + "etag": self.etag, + "createdOn": self.created_on, + "modifiedOn": self.modified_on, + "createdBy": self.created_by, + "modifiedBy": self.modified_by, + "parentId": self.parent_id, + "alias": self.alias, + "concreteType": concrete_types.PROJECT_ENTITY, + } + delete_none_keys(entity) + + return entity + @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"Project_Store: ID: {self.id}, Name: {self.name}" ) @@ -295,28 +318,21 @@ async def store_async( "synapse.id": self.id or "", } ) + if self.has_changed: - loop = asyncio.get_event_loop() - synapse_project = Synapse_Project( - id=self.id, - etag=self.etag, - name=self.name, - description=self.description, - alias=self.alias, - parentId=self.parent_id, + entity = await store_entity_with_bundle2( + entity=self.to_synapse_request(), + parent_id=self.parent_id, + annotations=Annotations(self.annotations).to_synapse_request() + if self.annotations + else None, + synapse_client=synapse_client, ) - delete_none_keys(synapse_project) - entity = await loop.run_in_executor( - None, - lambda: Synapse.get_client(synapse_client=synapse_client).store( - obj=synapse_project, - set_annotations=False, - createOrUpdate=False, - ), + self.fill_from_dict( + entity=entity["entity"], annotations=entity.get("annotations", None) ) - self.fill_from_dict(synapse_project=entity, set_annotations=False) - await store_entity_components( + await store_entity_components_file_folder_only( root_resource=self, failure_strategy=failure_strategy, synapse_client=synapse_client, diff --git a/synapseclient/models/services/__init__.py b/synapseclient/models/services/__init__.py index d1e7227ca..553de02ba 100644 --- a/synapseclient/models/services/__init__.py +++ b/synapseclient/models/services/__init__.py @@ -3,6 +3,13 @@ from synapseclient.models.services.storable_entity_components import ( FailureStrategy, store_entity_components, + store_entity_components_file_folder_only, ) -__all__ = ["store_entity_components", "store_entity", "FailureStrategy", "get_id"] +__all__ = [ + "store_entity_components", + "store_entity_components_file_folder_only", + "store_entity", + "FailureStrategy", + "get_id", +] diff --git a/synapseclient/models/services/storable_entity.py b/synapseclient/models/services/storable_entity.py index d70082e80..89f0aa58d 100644 --- a/synapseclient/models/services/storable_entity.py +++ b/synapseclient/models/services/storable_entity.py @@ -1,103 +1,162 @@ """Script used to store an entity to Synapse.""" -from typing import TYPE_CHECKING, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from opentelemetry import trace from synapseclient import Synapse from synapseclient.api import ( create_access_requirements_if_none, - post_entity, - put_entity, + get_entity_id_bundle2, + post_entity_bundle2_create, + put_entity_id_bundle2, ) from synapseclient.core.utils import get_properties if TYPE_CHECKING: - from synapseclient.models import File, Folder, Project + from synapseclient.models import ( + Annotations, + Dataset, + EntityView, + File, + Folder, + Project, + Table, + ) async def store_entity( - resource: Union["File", "Folder", "Project"], + resource: Union["File", "Folder", "Project", "Table", "Dataset", "EntityView"], entity: Dict[str, Union[str, bool, int, float]], + parent_id: Optional[str] = None, + acl: Optional[Dict[str, Any]] = None, + new_version: bool = False, + force_version: bool = False, *, synapse_client: Optional[Synapse] = None, -) -> bool: +) -> Dict[str, Any]: """ - Function to store an entity to synapse. + Function to store an entity to synapse using the bundle2 service. - TODO: This function is not complete and is a work in progress. + This function handles both creation and update of entities in Synapse: + 1. For new entities (without an ID): + - Creates a new entity with the provided data + - Pushes the new entity with bundle2 create endpoint + 2. For existing entities (with an ID): + - Updates the entity with the provided data + - Pushes the updated entity with bundle2 update endpoint Arguments: resource: The root dataclass instance we are storing data for. entity: The entity to store. + parent_id: The ID of the parent entity for creation. + acl: Access control list for the entity. + new_version: If True, create a new version of the entity. + force_version: If True, forces a new version of an entity even if nothing has changed. synapse_client: If not passed in and caching was not disabled by `Synapse.allow_client_caching(False)` this will use the last created instance from the Synapse class constructor. Returns: - If a read from Synapse is required to retireve the current state of the entity. + The entity data from the stored entity bundle. """ - query_params = {} - increment_version = False - # Create or update Entity in Synapse + from synapseclient.models import Annotations + + # First, handle the activity if it exists and needs to be stored + activity_id = None + if hasattr(resource, "activity") and resource.activity is not None: + # Store the activity first if it doesn't have an ID yet or if it's changed + last_persistent_instance = getattr(resource, "_last_persistent_instance", None) + activity_changed = ( + last_persistent_instance is None + or last_persistent_instance.activity != resource.activity + ) + + if not resource.activity.id or activity_changed: + resource.activity = await resource.activity.store_async( + synapse_client=synapse_client + ) + + activity_id = resource.activity.id + + # Prepare annotations if they exist + annotations = None + if hasattr(resource, "annotations") and resource.annotations: + annotations = Annotations(resource.annotations).to_synapse_request() + + # Set trace attributes if ID exists if resource.id: trace.get_current_span().set_attributes({"synapse.id": resource.id}) - if hasattr(resource, "version_number"): - if ( - resource.version_label - and resource.version_label - != resource._last_persistent_instance.version_label - ): - # a versionLabel implicitly implies incrementing - increment_version = True - elif resource.force_version and resource.version_number: - increment_version = True - entity["versionLabel"] = str(resource.version_number + 1) - - if increment_version: - query_params["newVersion"] = "true" - - updated_entity = await put_entity( - entity_id=resource.id, - request=get_properties(entity), - new_version=increment_version, + + # Handle versioning attributes if not already specified + # TODO: force_version is not yet supported in the bundle2 API: https://sagebionetworks.jira.com/browse/PLFM-8313 + if ( + not force_version + and hasattr(resource, "version_number") + and hasattr(resource, "force_version") + ): + if resource.force_version: + force_version = True + + # Get parent_id from resource if not specified + if parent_id is None: + parent_id = getattr(resource, "parent_id", None) + + # Get client + client = Synapse.get_client(synapse_client=synapse_client) + + # Determine if this is a create or update operation + entity_id = entity.get("id", None) + + # Construct bundle request based on provided data + bundle_request = {"entity": entity.to_synapse_request()} + + if annotations: + bundle_request["annotations"] = annotations + + if acl: + bundle_request["accessControlList"] = acl + + if activity_id: + bundle_request["activity"] = activity_id + + # Handle create or update + if not entity_id: + # This is a creation + client.logger.debug("Creating new entity via bundle2 API") + + # Create entity using bundle2 create endpoint + updated_entity = await post_entity_bundle2_create( + request=bundle_request, + generated_by=activity_id, synapse_client=synapse_client, ) else: - # TODO - When Link is implemented this needs to be completed - # If Link, get the target name, version number and concrete type and store in link properties - # if properties["concreteType"] == "org.sagebionetworks.repo.model.Link": - # target_properties = self._getEntity( - # properties["linksTo"]["targetId"], - # version=properties["linksTo"].get("targetVersionNumber"), - # ) - # if target_properties["parentId"] == properties["parentId"]: - # raise ValueError( - # "Cannot create a Link to an entity under the same parent." - # ) - # properties["linksToClassName"] = target_properties["concreteType"] - # if ( - # target_properties.get("versionNumber") is not None - # and properties["linksTo"].get("targetVersionNumber") is not None - # ): - # properties["linksTo"]["targetVersionNumber"] = target_properties[ - # "versionNumber" - # ] - # properties["name"] = target_properties["name"] - - updated_entity = await post_entity( - request=get_properties(entity), + # This is an update + client.logger.debug(f"Updating entity {entity_id} via bundle2 API") + + # If we're creating a new version or forcing one, we need to update + # the entity directly instead of via bundle2 + updated_entity = await put_entity_id_bundle2( + entity_id=entity_id, + request=bundle_request, + generated_by=activity_id, synapse_client=synapse_client, ) + # Handle access restrictions if needed if hasattr(resource, "is_restricted") and resource.is_restricted: - await create_access_requirements_if_none(entity_id=updated_entity.get("id")) + await create_access_requirements_if_none( + entity_id=updated_entity["entity"].get("id") + ) + # Set trace attributes trace.get_current_span().set_attributes( { - "synapse.id": updated_entity.get("id"), - "synapse.concrete_type": updated_entity.get("concreteType", ""), + "synapse.id": updated_entity["entity"].get("id"), + "synapse.concrete_type": updated_entity["entity"].get("concreteType", ""), } ) - return updated_entity + + return updated_entity["entity"] diff --git a/synapseclient/models/services/storable_entity_components.py b/synapseclient/models/services/storable_entity_components.py index 76c7622da..2928b990e 100644 --- a/synapseclient/models/services/storable_entity_components.py +++ b/synapseclient/models/services/storable_entity_components.py @@ -120,6 +120,77 @@ async def store_entity_components( return re_read_required +async def store_entity_components_file_folder_only( + root_resource: Union["File", "Folder", "Project", "Table", "Dataset", "EntityView"], + failure_strategy: FailureStrategy = FailureStrategy.LOG_EXCEPTION, + *, + synapse_client: Optional[Synapse] = None, +) -> bool: + """ + Function to store ancillary components of an entity to synapse. This function will + execute the stores in parallel. + + This is responsible for storing the annotations, activity, files and folders of a + resource to synapse. + + + Arguments: + root_resource: The root resource to store objects on. + synapse_client: If not passed in or None this will use the last client from the Synapse class constructor. + + Returns: + If a read from Synapse is required to retireve the current state of the entity. + """ + re_read_required = False + + tasks = [] + + if hasattr(root_resource, "files") and root_resource.files is not None: + for file in root_resource.files: + tasks.append( + asyncio.create_task( + file.store_async( + parent=root_resource, synapse_client=synapse_client + ) + ) + ) + + if hasattr(root_resource, "folders") and root_resource.folders is not None: + for folder in root_resource.folders: + tasks.append( + asyncio.create_task( + folder.store_async( + parent=root_resource, synapse_client=synapse_client + ) + ) + ) + + # tasks.append( + # asyncio.create_task( + # _store_activity_and_annotations( + # root_resource, synapse_client=synapse_client + # ) + # ) + # ) + + try: + tasks = [wrap_coroutine(task) for task in tasks] + for task in asyncio.as_completed(tasks): + result = await task + _resolve_store_task( + result=result, + failure_strategy=failure_strategy, + synapse_client=synapse_client, + ) + except Exception as ex: + Synapse.get_client(synapse_client=synapse_client).logger.exception(ex) + if failure_strategy == FailureStrategy.RAISE_EXCEPTION: + raise ex + + # TODO: Double check this logic. This might not be getting set properly from _resolve_store_task + return re_read_required + + def _resolve_store_task( result: Union[bool, "Folder", "File", BaseException], failure_strategy: FailureStrategy = FailureStrategy.LOG_EXCEPTION, diff --git a/synapseclient/models/submissionview.py b/synapseclient/models/submissionview.py index e285ed8d5..250e311d0 100644 --- a/synapseclient/models/submissionview.py +++ b/synapseclient/models/submissionview.py @@ -847,7 +847,7 @@ def _set_last_persistent_instance(self) -> None: deepcopy(self.scope_ids) if self.scope_ids else [] ) - def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": + def fill_from_dict(self, entity, annotations: Dict = None) -> "Self": """ Converts the data coming from the Synapse API into this datamodel. @@ -873,8 +873,8 @@ def fill_from_dict(self, entity, set_annotations: bool = True) -> "Self": self.is_search_enabled = entity.get("isSearchEnabled", False) self.scope_ids = [item for item in entity.get("scopeIds", [])] - if set_annotations: - self.annotations = Annotations.from_dict(entity.get("annotations", {})) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self diff --git a/synapseclient/models/table.py b/synapseclient/models/table.py index 9d305781c..07d97b2a5 100644 --- a/synapseclient/models/table.py +++ b/synapseclient/models/table.py @@ -1349,7 +1349,7 @@ def _set_last_persistent_instance(self) -> None: ) def fill_from_dict( - self, entity: Synapse_Table, set_annotations: bool = True + self, entity: Synapse_Table, annotations: Dict = None ) -> "Table": """ Converts the data coming from the Synapse API into this datamodel. @@ -1375,8 +1375,8 @@ def fill_from_dict( self.is_latest_version = entity.get("isLatestVersion", None) self.is_search_enabled = entity.get("isSearchEnabled", False) - if set_annotations: - self.annotations = Annotations.from_dict(entity.get("annotations", {})) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self def to_synapse_request(self): diff --git a/synapseclient/models/virtualtable.py b/synapseclient/models/virtualtable.py index 70b52b8d3..c3f484a81 100644 --- a/synapseclient/models/virtualtable.py +++ b/synapseclient/models/virtualtable.py @@ -11,7 +11,7 @@ from synapseclient.core.async_utils import async_to_sync from synapseclient.core.constants import concrete_types from synapseclient.core.utils import delete_none_keys -from synapseclient.models import Activity, Column +from synapseclient.models import Activity, Annotations, Column from synapseclient.models.mixins.access_control import AccessControllable from synapseclient.models.mixins.table_components import ( DeleteMixin, @@ -427,7 +427,7 @@ def _set_last_persistent_instance(self) -> None: ) def fill_from_dict( - self, entity: Dict[str, Any], set_annotations: bool = True + self, entity: Dict[str, Any], annotations: Dict = None ) -> "VirtualTable": """ Converts the data coming from the Synapse API into this datamodel. @@ -454,8 +454,8 @@ def fill_from_dict( self.is_search_enabled = entity.get("isSearchEnabled", False) self.defining_sql = entity.get("definingSQL", None) - if set_annotations: - self.annotations = entity.get("annotations", {}) + if annotations: + self.annotations = Annotations.from_dict(annotations.get("annotations", {})) return self diff --git a/tests/unit/synapseclient/mixins/unit_test_table_components.py b/tests/unit/synapseclient/mixins/unit_test_table_components.py index ab27fbb16..fe5f6bfae 100644 --- a/tests/unit/synapseclient/mixins/unit_test_table_components.py +++ b/tests/unit/synapseclient/mixins/unit_test_table_components.py @@ -85,9 +85,11 @@ def to_synapse_request(self) -> Any: }, } - def fill_from_dict(self, entity: Any, set_annotations: bool = True) -> None: + def fill_from_dict(self, entity: Any, annotations: Dict = {}) -> None: """Placeholder for fill_from_dict method""" self.__dict__.update(entity) + if annotations is not None: + self.__dict__.update(annotations) # TODO: Is this right? @pytest.fixture(autouse=True, scope="function") def init_syn(self, syn: Synapse) -> None: diff --git a/tests/unit/synapseclient/models/async/unit_test_dataset_async.py b/tests/unit/synapseclient/models/async/unit_test_dataset_async.py index 261c3fcfa..4b5bb891a 100644 --- a/tests/unit/synapseclient/models/async/unit_test_dataset_async.py +++ b/tests/unit/synapseclient/models/async/unit_test_dataset_async.py @@ -61,7 +61,7 @@ def test_fill_from_dict(self): # GIVEN an empty Dataset dataset = Dataset() # WHEN I fill it from a Synapse response - dataset.fill_from_dict(self.synapse_response, set_annotations=True) + dataset.fill_from_dict(self.synapse_response, annotations=self.synapse_response) # THEN I expect the Dataset to be filled with the expected values assert dataset.id == self.synapse_response["id"] assert dataset.name == self.synapse_response["name"] @@ -259,7 +259,9 @@ def test_fill_from_dict(self): # GIVEN an empty DatasetCollection dataset_collection = DatasetCollection() # WHEN I fill it from a Synapse response - dataset_collection.fill_from_dict(self.synapse_response, set_annotations=True) + dataset_collection.fill_from_dict( + self.synapse_response, annotations=self.synapse_response + ) # THEN I expect the DatasetCollection to be filled with the expected values assert dataset_collection.id == self.synapse_response["id"] assert dataset_collection.name == self.synapse_response["name"] diff --git a/tests/unit/synapseclient/models/async/unit_test_project_async.py b/tests/unit/synapseclient/models/async/unit_test_project_async.py index 33c56468e..3c4b5d69c 100644 --- a/tests/unit/synapseclient/models/async/unit_test_project_async.py +++ b/tests/unit/synapseclient/models/async/unit_test_project_async.py @@ -5,7 +5,6 @@ import pytest -from synapseclient import Project as Synapse_Project from synapseclient import Synapse from synapseclient.core.constants import concrete_types from synapseclient.core.constants.concrete_types import FILE_ENTITY @@ -30,19 +29,6 @@ class TestProject: def init_syn(self, syn: Synapse) -> None: self.syn = syn - def get_example_synapse_project_output(self) -> Synapse_Project: - return Synapse_Project( - id=PROJECT_ID, - name=PROJECT_NAME, - parentId=PARENT_ID, - description=DERSCRIPTION_PROJECT, - etag=ETAG, - createdOn=CREATED_ON, - modifiedOn=MODIFIED_ON, - createdBy=CREATED_BY, - modifiedBy=MODIFIED_BY, - ) - def get_example_rest_api_project_output(self) -> Dict[str, str]: return { "entity": { @@ -63,7 +49,7 @@ def test_fill_from_dict(self) -> None: # GIVEN an example Synapse Project `get_example_synapse_project_output` # WHEN I call `fill_from_dict` with the example Synapse Project project_output = Project().fill_from_dict( - self.get_example_synapse_project_output() + self.get_example_rest_api_project_output().get("entity") ) # THEN the Project object should be filled with the example Synapse Project @@ -87,34 +73,23 @@ async def test_store_with_id(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object - with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), + with patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, - return_value=( - { - "entity": { - "concreteType": concrete_types.PROJECT_ENTITY, - "id": project.id, - } - } - ), + return_value=mock_return_value, ) as mocked_get: result = await project.store_async(synapse_client=self.syn) - # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + # THEN we should call the method + mocked_client_call.assert_called_once() # AND we should call the get method mocked_get.assert_called_once() @@ -246,25 +221,21 @@ async def test_store_after_get_with_changes(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object - with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), - ) as mocked_store, patch( + with patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, + ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", ) as mocked_get: result = await project.store_async(synapse_client=self.syn) - # THEN we should call store because there are changes - mocked_store.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + # THEN we should call store_entity_with_bundle2 because there are changes + mocked_client_call.assert_called_once() # AND we should not call get as we already have mocked_get.assert_not_called() @@ -297,14 +268,17 @@ async def test_store_with_annotations(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object with patch( - "synapseclient.models.project.store_entity_components", + "synapseclient.models.project.store_entity_components_file_folder_only", return_value=(None), - ) as mocked_store_entity_components, patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), + ) as mocked_store_entity_components, patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, @@ -319,15 +293,8 @@ async def test_store_with_annotations(self) -> None: ) as mocked_get: result = await project.store_async(synapse_client=self.syn) - # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + # THEN we should call store_entity_with_bundle2 + mocked_client_call.assert_called_once() # AND we should call the get method mocked_get.assert_called_once() @@ -361,48 +328,45 @@ async def test_store_with_name_and_parent_id(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), - ) as mocked_client_call, patch.object( self.syn, "findEntityId", return_value=PROJECT_ID, - ) as mocked_get, patch( + ) as mocked_find_entity_id, patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, + ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, return_value=( { "entity": { "concreteType": concrete_types.PROJECT_ENTITY, - "id": project.id, + "id": PROJECT_ID, } } ), ) as mocked_get: result = await project.store_async(synapse_client=self.syn) - # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - name=project.name, - parent=project.parent_id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, + # THEN we should call store_entity_with_bundle2 + mocked_client_call.assert_called_once() + + # AND we should find the entity ID + mocked_find_entity_id.assert_called_once_with( + name=project.name, + parent=project.parent_id, ) # AND we should call the get method mocked_get.assert_called_once() - # AND findEntityId should be called - mocked_get.assert_called_once() - - # AND the project should be stored + # AND the project should be stored with the mock return data assert result.id == PROJECT_ID assert result.name == PROJECT_NAME assert result.parent_id == PARENT_ID diff --git a/tests/unit/synapseclient/models/synchronous/unit_test_project.py b/tests/unit/synapseclient/models/synchronous/unit_test_project.py index d2d484d29..4e0b239ea 100644 --- a/tests/unit/synapseclient/models/synchronous/unit_test_project.py +++ b/tests/unit/synapseclient/models/synchronous/unit_test_project.py @@ -5,7 +5,6 @@ import pytest -from synapseclient import Project as Synapse_Project from synapseclient import Synapse from synapseclient.core.constants import concrete_types from synapseclient.core.constants.concrete_types import FILE_ENTITY @@ -30,19 +29,6 @@ class TestProject: def init_syn(self, syn: Synapse) -> None: self.syn = syn - def get_example_synapse_project_output(self) -> Synapse_Project: - return Synapse_Project( - id=PROJECT_ID, - name=PROJECT_NAME, - parentId=PARENT_ID, - description=DERSCRIPTION_PROJECT, - etag=ETAG, - createdOn=CREATED_ON, - modifiedOn=MODIFIED_ON, - createdBy=CREATED_BY, - modifiedBy=MODIFIED_BY, - ) - def get_example_rest_api_project_output(self) -> Dict[str, str]: return { "entity": { @@ -63,7 +49,7 @@ def test_fill_from_dict(self) -> None: # GIVEN an example Synapse Project `get_example_synapse_project_output` # WHEN I call `fill_from_dict` with the example Synapse Project project_output = Project().fill_from_dict( - self.get_example_synapse_project_output() + self.get_example_rest_api_project_output().get("entity") ) # THEN the Project object should be filled with the example Synapse Project @@ -87,34 +73,23 @@ def test_store_with_id(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object - with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), + with patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, - return_value=( - { - "entity": { - "concreteType": concrete_types.PROJECT_ENTITY, - "id": project.id, - } - } - ), + return_value=mock_return_value, ) as mocked_get: result = project.store(synapse_client=self.syn) # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + mocked_client_call.assert_called_once() # AND we should call the get method mocked_get.assert_called_once() @@ -246,25 +221,21 @@ def test_store_after_get_with_changes(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object - with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), - ) as mocked_store, patch( + with patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, + ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", ) as mocked_get: result = project.store(synapse_client=self.syn) - # THEN we should call store because there are changes - mocked_store.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + # THEN we should call store_entity_with_bundle2 because there are changes + mocked_client_call.assert_called_once() # AND we should not call get as we already have mocked_get.assert_not_called() @@ -297,14 +268,17 @@ def test_store_with_annotations(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object with patch( - "synapseclient.models.project.store_entity_components", + "synapseclient.models.project.store_entity_components_file_folder_only", return_value=(None), - ) as mocked_store_entity_components, patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), + ) as mocked_store_entity_components, patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, @@ -319,15 +293,8 @@ def test_store_with_annotations(self) -> None: ) as mocked_get: result = project.store(synapse_client=self.syn) - # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, - ) + # THEN we should call store_entity_with_bundle2 + mocked_client_call.assert_called_once() # AND we should call the get method mocked_get.assert_called_once() @@ -361,48 +328,45 @@ def test_store_with_name_and_parent_id(self) -> None: description = str(uuid.uuid4()) project.description = description + # Create the mock return value + mock_return_value = self.get_example_rest_api_project_output() + # WHEN I call `store` with the Project object with patch.object( - self.syn, - "store", - return_value=(self.get_example_synapse_project_output()), - ) as mocked_client_call, patch.object( self.syn, "findEntityId", return_value=PROJECT_ID, - ) as mocked_get, patch( + ) as mocked_find_entity_id, patch( + "synapseclient.models.project.store_entity_with_bundle2", + new_callable=AsyncMock, + return_value=mock_return_value, + ) as mocked_client_call, patch( "synapseclient.api.entity_factory.get_entity_id_bundle2", new_callable=AsyncMock, return_value=( { "entity": { "concreteType": concrete_types.PROJECT_ENTITY, - "id": project.id, + "id": PROJECT_ID, } } ), ) as mocked_get: result = project.store(synapse_client=self.syn) - # THEN we should call the method with this data - mocked_client_call.assert_called_once_with( - obj=Synapse_Project( - id=project.id, - name=project.name, - parent=project.parent_id, - description=description, - ), - set_annotations=False, - createOrUpdate=False, + # THEN we should call store_entity_with_bundle2 + mocked_client_call.assert_called_once() + + # AND we should find the entity ID + mocked_find_entity_id.assert_called_once_with( + name=project.name, + parent=project.parent_id, ) # AND we should call the get method mocked_get.assert_called_once() - # AND findEntityId should be called - mocked_get.assert_called_once() - - # AND the project should be stored + # AND the project should be stored with the mock return data assert result.id == PROJECT_ID assert result.name == PROJECT_NAME assert result.parent_id == PARENT_ID