diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index aa558f636..80e44eb8b 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,14 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +36,10 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -65,6 +76,8 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 10cfad884..5083a584f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,14 @@ from .. 
import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +36,10 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.aio.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +76,8 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 10db0c7a9..53c8a8f82 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ 
b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,8 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +26,8 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 0916d8a28..df72fed83 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -52,6 +52,14 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_legacy_send_message_request, + build_pilots_search_logs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1826,7 +1834,7 @@ async def patch_metadata(self, body: Union[Dict[str, Dict[str, Any]], IO[bytes]] @overload async def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -1840,7 +1848,7 @@ async def search( **TODO: Add more docs**. :param body: Default value is None. 
- :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1886,7 +1894,7 @@ async def search( @distributed_trace_async async def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -1898,8 +1906,8 @@ async def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1929,7 +1937,7 @@ async def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -1968,14 +1976,14 @@ async def search( @overload async def summary( - self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any ) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". 
:paramtype content_type: str @@ -2001,13 +2009,13 @@ async def summary(self, body: IO[bytes], *, content_type: str = "application/jso """ @distributed_trace_async - async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. - :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2032,7 +2040,7 @@ async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwar if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2157,3 +2165,831 @@ async def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. 
+ + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + 
**kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + 
map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def search_logs( + self, + body: Optional[_models.SearchParams] = 
None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, 
content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
class PilotsLegacyOperations:
    """Async operations for the legacy pilot endpoints (``/api/pilots/legacy``).

    .. warning::
        **DO NOT** instantiate this class directly.

        Instead, you should access the following operations through
        :class:`~_generated.aio.Dirac`'s
        :attr:`pilots_legacy` attribute.
    """

    models = _models

    def __init__(self, *args, **kwargs) -> None:
        # Generated plumbing: the four collaborators may arrive positionally in
        # the order (client, config, serializer, deserializer) or as keywords.
        input_args = list(args)
        self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client")
        self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config")
        self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer")
        self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer")

    @overload
    async def send_message(
        self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any
    ) -> None:
        """Send Message.

        Send logs with legacy pilot.

        :param body: Required.
        :type body: ~_generated.models.BodyPilotsLegacySendMessage
        :keyword content_type: Body Parameter content-type. Content type parameter for JSON body.
         Default value is "application/json".
        :paramtype content_type: str
        :return: None
        :rtype: None
        :raises ~azure.core.exceptions.HttpResponseError:
        """

    @overload
    async def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None:
        """Send Message.

        Send logs with legacy pilot.

        :param body: Required.
        :type body: IO[bytes]
        :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
         Default value is "application/json".
        :paramtype content_type: str
        :return: None
        :rtype: None
        :raises ~azure.core.exceptions.HttpResponseError:
        """

    @distributed_trace_async
    async def send_message(self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any) -> None:
        """Send Message.

        Send logs with legacy pilot.

        :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required.
        :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes]
        :return: None
        :rtype: None
        :raises ~azure.core.exceptions.HttpResponseError:
        """
        # Map well-known HTTP statuses to typed azure-core exceptions; callers
        # may extend or override the mapping via the ``error_map`` keyword.
        error_map: MutableMapping = {
            401: ClientAuthenticationError,
            404: ResourceNotFoundError,
            409: ResourceExistsError,
            304: ResourceNotModifiedError,
        }
        error_map.update(kwargs.pop("error_map", {}) or {})

        _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
        _params = kwargs.pop("params", {}) or {}

        content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
        cls: ClsType[None] = kwargs.pop("cls", None)

        content_type = content_type or "application/json"
        _json = None
        _content = None
        # Raw streams / bytes pass through untouched; a model body is
        # serialized to JSON via the generated serializer.
        if isinstance(body, (IOBase, bytes)):
            _content = body
        else:
            _json = self._serialize.body(body, "BodyPilotsLegacySendMessage")

        _request = build_pilots_legacy_send_message_request(
            content_type=content_type,
            json=_json,
            content=_content,
            headers=_headers,
            params=_params,
        )
        _request.url = self._client.format_url(_request.url)

        _stream = False
        pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
            _request, stream=_stream, **kwargs
        )

        response = pipeline_response.http_response

        # The endpoint returns 204 No Content on success; anything else is an error.
        if response.status_code not in [204]:
            map_error(status_code=response.status_code, response=response, error_map=error_map)
            raise HttpResponseError(response=response)

        if cls:
            return cls(pipeline_response, None, {})  # type: ignore
available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations +from ....patches.pilots.aio import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 7343700e4..ce33e799c 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -14,27 +14,32 @@ from ._models import ( # type: ignore BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddPilotStamps, + BodyPilotsLegacySendMessage, + BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, HeartbeatData, InitiateDeviceFlowResponse, InsertedJob, JobCommand, - JobSearchParams, - JobSearchParamsSearchItem, JobStatusUpdate, - JobSummaryParams, - JobSummaryParamsSearchItem, + LogLine, Metadata, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, ScalarSearchSpec, ScalarSearchSpecValue, + SearchParams, + SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, + SummaryParams, + SummaryParamsSearchItem, SupportInfo, TokenResponse, UserInfoResponse, @@ -48,6 +53,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -61,27 +67,32 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddPilotStamps", + "BodyPilotsLegacySendMessage", + "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", "HeartbeatData", "InitiateDeviceFlowResponse", "InsertedJob", "JobCommand", - "JobSearchParams", - "JobSearchParamsSearchItem", "JobStatusUpdate", - "JobSummaryParams", - "JobSummaryParamsSearchItem", + "LogLine", "Metadata", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", 
class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    """Lifecycle status of a pilot.

    Inherits from ``str`` so members compare equal to their literal string
    values; ``CaseInsensitiveEnumMeta`` makes member lookup case-insensitive.
    """

    SUBMITTED = "Submitted"
    WAITING = "Waiting"
    RUNNING = "Running"
    DONE = "Done"
    FAILED = "Failed"
    DELETED = "Deleted"
    ABORTED = "Aborted"
    UNKNOWN = "Unknown"
class BodyPilotsAddPilotStamps(_serialization.Model):
    """Body_pilots_add_pilot_stamps.

    Request body for registering new pilot stamps.

    All required parameters must be populated in order to send to server.

    :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required.
    :vartype pilot_stamps: list[str]
    :ivar vo: Pilot virtual organization. Required.
    :vartype vo: str
    :ivar grid_type: Grid type of the pilots.
    :vartype grid_type: str
    :ivar grid_site: Pilots grid site.
    :vartype grid_site: str
    :ivar destination_site: Pilots destination site.
    :vartype destination_site: str
    :ivar pilot_references: Association of a pilot reference with a pilot stamp.
    :vartype pilot_references: dict[str, str]
    :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running",
     "Done", "Failed", "Deleted", "Aborted", and "Unknown".
    :vartype pilot_status: str or ~_generated.models.PilotStatus
    """

    _validation = {
        "pilot_stamps": {"required": True},
        "vo": {"required": True},
    }

    _attribute_map = {
        "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"},
        "vo": {"key": "vo", "type": "str"},
        "grid_type": {"key": "grid_type", "type": "str"},
        "grid_site": {"key": "grid_site", "type": "str"},
        "destination_site": {"key": "destination_site", "type": "str"},
        "pilot_references": {"key": "pilot_references", "type": "{str}"},
        "pilot_status": {"key": "pilot_status", "type": "str"},
    }

    def __init__(
        self,
        *,
        pilot_stamps: List[str],
        vo: str,
        grid_type: str = "Dirac",
        grid_site: str = "Unknown",
        destination_site: str = "NotAssigned",
        pilot_references: Optional[Dict[str, str]] = None,
        pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None,
        **kwargs: Any
    ) -> None:
        """
        :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required.
        :paramtype pilot_stamps: list[str]
        :keyword vo: Pilot virtual organization. Required.
        :paramtype vo: str
        :keyword grid_type: Grid type of the pilots.
        :paramtype grid_type: str
        :keyword grid_site: Pilots grid site.
        :paramtype grid_site: str
        :keyword destination_site: Pilots destination site.
        :paramtype destination_site: str
        :keyword pilot_references: Association of a pilot reference with a pilot stamp.
        :paramtype pilot_references: dict[str, str]
        :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting",
         "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown".
        :paramtype pilot_status: str or ~_generated.models.PilotStatus
        """
        # NOTE(review): the "Dirac"/"Unknown"/"NotAssigned" defaults presumably
        # mirror the server-side schema defaults — confirm against the OpenAPI spec.
        super().__init__(**kwargs)
        self.pilot_stamps = pilot_stamps
        self.vo = vo
        self.grid_type = grid_type
        self.grid_site = grid_site
        self.destination_site = destination_site
        self.pilot_references = pilot_references
        self.pilot_status = pilot_status
class BodyPilotsLegacySendMessage(_serialization.Model):
    """Body_pilots/legacy_send_message.

    Request body for the legacy pilot logging endpoint.

    All required parameters must be populated in order to send to server.

    :ivar lines: Message from the pilot to the logging system. Required.
    :vartype lines: list[~_generated.models.LogLine]
    :ivar pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in it.
     Required.
    :vartype pilot_stamp: str
    """

    _validation = {
        "lines": {"required": True},
        "pilot_stamp": {"required": True},
    }

    _attribute_map = {
        "lines": {"key": "lines", "type": "[LogLine]"},
        "pilot_stamp": {"key": "pilot_stamp", "type": "str"},
    }

    def __init__(self, *, lines: List["_models.LogLine"], pilot_stamp: str, **kwargs: Any) -> None:
        """
        :keyword lines: Message from the pilot to the logging system. Required.
        :paramtype lines: list[~_generated.models.LogLine]
        :keyword pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in
         it. Required.
        :paramtype pilot_stamp: str
        """
        super().__init__(**kwargs)
        self.lines = lines
        self.pilot_stamp = pilot_stamp
class BodyPilotsUpdatePilotFields(_serialization.Model):
    """Body_pilots_update_pilot_fields.

    Request body for bulk-updating pilot metadata fields.

    All required parameters must be populated in order to send to server.

    :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required.
    :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping]
    """

    _validation = {
        "pilot_stamps_to_fields_mapping": {"required": True},
    }

    _attribute_map = {
        "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"},
    }

    def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None:
        """
        :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change.
         Required.
        :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping]
        """
        super().__init__(**kwargs)
        self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping
class LogLine(_serialization.Model):
    """A single structured log record sent by a pilot.

    All required parameters must be populated in order to send to server.

    :ivar timestamp: Timestamp. Required.
    :vartype timestamp: str
    :ivar severity: Severity. Required.
    :vartype severity: str
    :ivar message: Message. Required.
    :vartype message: str
    :ivar scope: Scope. Required.
    :vartype scope: str
    """

    # All four fields are mandatory on the wire.
    _validation = {
        "timestamp": {"required": True},
        "severity": {"required": True},
        "message": {"required": True},
        "scope": {"required": True},
    }

    _attribute_map = {
        "timestamp": {"key": "timestamp", "type": "str"},
        "severity": {"key": "severity", "type": "str"},
        "message": {"key": "message", "type": "str"},
        "scope": {"key": "scope", "type": "str"},
    }

    def __init__(self, *, timestamp: str, severity: str, message: str, scope: str, **kwargs: Any) -> None:
        """
        :keyword timestamp: Timestamp. Required. Presumably ISO-8601 — confirm
         against the server schema.
        :paramtype timestamp: str
        :keyword severity: Severity. Required.
        :paramtype severity: str
        :keyword message: Message. Required.
        :paramtype message: str
        :keyword scope: Scope. Required.
        :paramtype scope: str
        """
        super().__init__(**kwargs)
        self.timestamp = timestamp
        self.severity = severity
        self.message = message
        self.scope = scope
class PilotFieldsMapping(_serialization.Model):
    """All the fields that a user can modify on a Pilot (except PilotStamp).

    All required parameters must be populated in order to send to server.

    :ivar pilot_stamp: Pilotstamp. Required.
    :vartype pilot_stamp: str
    :ivar status_reason: Statusreason.
    :vartype status_reason: str
    :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done",
     "Failed", "Deleted", "Aborted", and "Unknown".
    :vartype status: str or ~_generated.models.PilotStatus
    :ivar bench_mark: Benchmark.
    :vartype bench_mark: float
    :ivar destination_site: Destinationsite.
    :vartype destination_site: str
    :ivar queue: Queue.
    :vartype queue: str
    :ivar grid_site: Gridsite.
    :vartype grid_site: str
    :ivar grid_type: Gridtype.
    :vartype grid_type: str
    :ivar accounting_sent: Accountingsent.
    :vartype accounting_sent: bool
    :ivar current_job_id: Currentjobid.
    :vartype current_job_id: int
    """

    # Only the stamp identifying the pilot is mandatory; every other field is
    # an optional patch value.
    _validation = {
        "pilot_stamp": {"required": True},
    }

    # NOTE: unlike the other models in this file, the wire keys here are
    # PascalCase (e.g. "PilotStamp", "CurrentJobID") — they map directly onto
    # the pilot DB column names, so do not "normalize" them.
    _attribute_map = {
        "pilot_stamp": {"key": "PilotStamp", "type": "str"},
        "status_reason": {"key": "StatusReason", "type": "str"},
        "status": {"key": "Status", "type": "str"},
        "bench_mark": {"key": "BenchMark", "type": "float"},
        "destination_site": {"key": "DestinationSite", "type": "str"},
        "queue": {"key": "Queue", "type": "str"},
        "grid_site": {"key": "GridSite", "type": "str"},
        "grid_type": {"key": "GridType", "type": "str"},
        "accounting_sent": {"key": "AccountingSent", "type": "bool"},
        "current_job_id": {"key": "CurrentJobID", "type": "int"},
    }

    def __init__(
        self,
        *,
        pilot_stamp: str,
        status_reason: Optional[str] = None,
        status: Optional[Union[str, "_models.PilotStatus"]] = None,
        bench_mark: Optional[float] = None,
        destination_site: Optional[str] = None,
        queue: Optional[str] = None,
        grid_site: Optional[str] = None,
        grid_type: Optional[str] = None,
        accounting_sent: Optional[bool] = None,
        current_job_id: Optional[int] = None,
        **kwargs: Any
    ) -> None:
        """
        :keyword pilot_stamp: Pilotstamp. Required.
        :paramtype pilot_stamp: str
        :keyword status_reason: Statusreason.
        :paramtype status_reason: str
        :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done",
         "Failed", "Deleted", "Aborted", and "Unknown".
        :paramtype status: str or ~_generated.models.PilotStatus
        :keyword bench_mark: Benchmark.
        :paramtype bench_mark: float
        :keyword destination_site: Destinationsite.
        :paramtype destination_site: str
        :keyword queue: Queue.
        :paramtype queue: str
        :keyword grid_site: Gridsite.
        :paramtype grid_site: str
        :keyword grid_type: Gridtype.
        :paramtype grid_type: str
        :keyword accounting_sent: Accountingsent.
        :paramtype accounting_sent: bool
        :keyword current_job_id: Currentjobid.
        :paramtype current_job_id: int
        """
        super().__init__(**kwargs)
        self.pilot_stamp = pilot_stamp
        self.status_reason = status_reason
        self.status = status
        self.bench_mark = bench_mark
        self.destination_site = destination_site
        self.queue = queue
        self.grid_site = grid_site
        self.grid_type = grid_type
        self.accounting_sent = accounting_sent
        self.current_job_id = current_job_id
class SearchParams(_serialization.Model):
    """SearchParams.

    Generic query description for the ``search`` endpoints: which columns to
    return, how to filter, how to sort, and whether to de-duplicate rows.

    :ivar parameters: Parameters.
    :vartype parameters: list[str]
    :ivar search: Search.
    :vartype search: list[~_generated.models.SearchParamsSearchItem]
    :ivar sort: Sort.
    :vartype sort: list[~_generated.models.SortSpec]
    :ivar distinct: Distinct.
    :vartype distinct: bool
    """

    _attribute_map = {
        "parameters": {"key": "parameters", "type": "[str]"},
        "search": {"key": "search", "type": "[SearchParamsSearchItem]"},
        "sort": {"key": "sort", "type": "[SortSpec]"},
        "distinct": {"key": "distinct", "type": "bool"},
    }

    def __init__(
        self,
        *,
        parameters: Optional[List[str]] = None,
        search: Optional[List["_models.SearchParamsSearchItem"]] = None,
        sort: Optional[List["_models.SortSpec"]] = None,
        distinct: bool = False,
        **kwargs: Any
    ) -> None:
        """
        :keyword parameters: Parameters.
        :paramtype parameters: list[str]
        :keyword search: Search. Defaults to an empty filter list.
        :paramtype search: list[~_generated.models.SearchParamsSearchItem]
        :keyword sort: Sort. Defaults to an empty sort list.
        :paramtype sort: list[~_generated.models.SortSpec]
        :keyword distinct: Distinct.
        :paramtype distinct: bool
        """
        super().__init__(**kwargs)
        self.parameters = parameters
        # Bug fix: the generated signature used mutable defaults (``= []``),
        # so every instance created without these keywords shared the SAME
        # list object — mutating one instance's ``search`` leaked into all
        # others. Use a None sentinel and build a fresh list per instance.
        self.search = search if search is not None else []
        self.sort = sort if sort is not None else []
        self.distinct = distinct


class SearchParamsSearchItem(_serialization.Model):
    """Marker type for items of :attr:`SearchParams.search`."""
class SummaryParams(_serialization.Model):
    """SummaryParams.

    Query description for the ``summary`` endpoints: group rows by the given
    columns, optionally restricted by search filters.

    All required parameters must be populated in order to send to server.

    :ivar grouping: Grouping. Required.
    :vartype grouping: list[str]
    :ivar search: Search.
    :vartype search: list[~_generated.models.SummaryParamsSearchItem]
    """

    _validation = {
        "grouping": {"required": True},
    }

    _attribute_map = {
        "grouping": {"key": "grouping", "type": "[str]"},
        "search": {"key": "search", "type": "[SummaryParamsSearchItem]"},
    }

    def __init__(
        self, *, grouping: List[str], search: Optional[List["_models.SummaryParamsSearchItem"]] = None, **kwargs: Any
    ) -> None:
        """
        :keyword grouping: Grouping. Required.
        :paramtype grouping: list[str]
        :keyword search: Search. Defaults to an empty filter list.
        :paramtype search: list[~_generated.models.SummaryParamsSearchItem]
        """
        super().__init__(**kwargs)
        self.grouping = grouping
        # Bug fix: ``search`` previously defaulted to a shared mutable ``[]``;
        # normalize a None sentinel to a fresh list per instance instead.
        self.search = search if search is not None else []


class SummaryParamsSearchItem(_serialization.Model):
    """Marker type for items of :attr:`SummaryParams.search`."""
def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest:
    # Build the POST /api/pilots/ request used to register new pilot stamps.
    # The JSON body is attached later by the caller via **kwargs (json=/content=).
    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
    accept = _headers.pop("Accept", "application/json")

    # Construct URL
    _url = "/api/pilots/"

    # Construct headers
    if content_type is not None:
        _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
    _headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs)
def build_pilots_delete_pilots_request(
    *,
    pilot_stamps: Optional[List[str]] = None,
    age_in_days: Optional[int] = None,
    delete_only_aborted: bool = False,
    **kwargs: Any
) -> HttpRequest:
    """Build the DELETE ``/api/pilots/`` request.

    Pilots can be selected by explicit stamps and/or by age; ``delete_only_aborted``
    restricts the deletion to aborted pilots.
    """
    _params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    # Construct URL
    _url = "/api/pilots/"

    # Construct parameters
    if pilot_stamps is not None:
        _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]")
    if age_in_days is not None:
        _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int")
    # Fix: ``delete_only_aborted`` is a non-optional bool (default False), so the
    # previous ``is not None`` guard was always true — serialize unconditionally.
    _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool")

    return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs)


def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest:
    """Build the PATCH ``/api/pilots/metadata`` request.

    No Accept header is set: the endpoint returns no body on success.
    """
    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
    # Construct URL
    _url = "/api/pilots/metadata"

    # Construct headers
    if content_type is not None:
        _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")

    return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs)


def build_pilots_get_pilot_jobs_request(
    *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any
) -> HttpRequest:
    """Build the GET ``/api/pilots/jobs`` request.

    Both filters are optional; omitted filters are simply not sent.
    """
    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    _params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    accept = _headers.pop("Accept", "application/json")

    # Construct URL
    _url = "/api/pilots/jobs"

    # Construct parameters
    if pilot_stamp is not None:
        _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str")
    if job_id is not None:
        _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int")

    # Construct headers
    _headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs)
def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest:
    """Build the paginated POST ``/api/pilots/search`` request."""
    headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", headers.pop("Content-Type", None))
    accept = headers.pop("Accept", "application/json")

    # Pagination query parameters.
    if page is not None:
        params["page"] = _SERIALIZER.query("page", page, "int")
    if per_page is not None:
        params["per_page"] = _SERIALIZER.query("per_page", per_page, "int")

    # Body/response negotiation headers.
    if content_type is not None:
        headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
    headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="POST", url="/api/pilots/search", params=params, headers=headers, **kwargs)


def build_pilots_search_logs_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest:
    """Build the paginated POST ``/api/pilots/search/logs`` request."""
    headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", headers.pop("Content-Type", None))
    accept = headers.pop("Accept", "application/json")

    # Pagination query parameters.
    if page is not None:
        params["page"] = _SERIALIZER.query("page", page, "int")
    if per_page is not None:
        params["per_page"] = _SERIALIZER.query("per_page", per_page, "int")

    # Body/response negotiation headers.
    if content_type is not None:
        headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
    headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="POST", url="/api/pilots/search/logs", params=params, headers=headers, **kwargs)
def build_pilots_summary_request(**kwargs: Any) -> HttpRequest:
    """Build the POST ``/api/pilots/summary`` request."""
    headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", headers.pop("Content-Type", None))
    accept = headers.pop("Accept", "application/json")

    # Body/response negotiation headers.
    if content_type is not None:
        headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
    headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="POST", url="/api/pilots/summary", headers=headers, **kwargs)


def build_pilots_legacy_send_message_request(**kwargs: Any) -> HttpRequest:
    """Build the POST ``/api/pilots/legacy/message`` request (no Accept header: 204 response)."""
    headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

    content_type: Optional[str] = kwargs.pop("content_type", headers.pop("Content-Type", None))

    if content_type is not None:
        headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")

    return HttpRequest(method="POST", url="/api/pilots/legacy/message", headers=headers, **kwargs)
@@ -2411,7 +2567,7 @@ def search( @distributed_trace def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -2423,8 +2579,8 @@ def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -2454,7 +2610,7 @@ def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -2492,13 +2648,13 @@ def search( return deserialized # type: ignore @overload - def summary(self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2524,13 +2680,13 @@ def summary(self, body: IO[bytes], *, content_type: str = "application/json", ** """ @distributed_trace - def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. 
- :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2555,7 +2711,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: An if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2680,3 +2836,829 @@ def submit_jdl_jobs(self, body: Union[List[str], IO[bytes]], **kwargs: Any) -> L return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. 
Content type parameter for JSON body. +        Default value is "application/json". +        :paramtype content_type: str +        :return: any +        :rtype: any +        :raises ~azure.core.exceptions.HttpResponseError: +        """ + +    @overload +    def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: +        """Add Pilot Stamps. + +        Endpoint where you can create pilots with their references. + +        If a pilot stamp already exists, it will block the insertion. + +        :param body: Required. +        :type body: IO[bytes] +        :keyword content_type: Body Parameter content-type. Content type parameter for binary body. +        Default value is "application/json". +        :paramtype content_type: str +        :return: any +        :rtype: any +        :raises ~azure.core.exceptions.HttpResponseError: +        """ + +    @distributed_trace +    def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: +        """Add Pilot Stamps. + +        Endpoint where you can create pilots with their references. + +        If a pilot stamp already exists, it will block the insertion. + +        :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required.
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. 
Or you provide pilot_stamps, so you can delete pilots by their stamp +        #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + +        Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + +        :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. +        :paramtype pilot_stamps: list[str] +        :keyword age_in_days: The number of days that define the maximum age of pilots to be +         deleted. Pilots older than this age will be considered for deletion. Default value is None. +        :paramtype age_in_days: int +        :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is +         'Aborted'. If set to True, only pilots with the 'Aborted' status will be deleted. It defaults to +         False; set it to True to avoid accidentally deleting non-aborted pilots. This flag is only used for deletion by time. Default value +         is False. +        :paramtype delete_only_aborted: bool +        :return: None +        :rtype: None +        :raises ~azure.core.exceptions.HttpResponseError: +        """ +        error_map: MutableMapping = { +            401: ClientAuthenticationError, +            404: ResourceNotFoundError, +            409: ResourceExistsError, +            304: ResourceNotModifiedError, +        } +        error_map.update(kwargs.pop("error_map", {}) or {}) + +        _headers = kwargs.pop("headers", {}) or {} +        _params = kwargs.pop("params", {}) or {} + +        cls: ClsType[None] = kwargs.pop("cls", None) + +        _request = build_pilots_delete_pilots_request( +            pilot_stamps=pilot_stamps, +            age_in_days=age_in_days, +            delete_only_aborted=delete_only_aborted, +            headers=_headers, +            params=_params, +        ) +        _request.url = self._client.format_url(_request.url) + +        _stream = False +        pipeline_response: PipelineResponse = self._client._pipeline.run(  # pylint: disable=protected-access +            _request, stream=_stream, **kwargs +        ) + +        response = pipeline_response.http_response + +        if response.status_code not in [204]: +            map_error(status_code=response.status_code, response=response, error_map=error_map) +            raise
HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + 
page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = 
"application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_legacy` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def send_message( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index b7b8c67fa..b14e98b84 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this 
package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py new file mode 100644 index 000000000..ac533a67c --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator_async import distributed_trace_async + +from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. 
# We're intentionally ignoring overrides here because we want to change the interface.
# mypy: disable-error-code=override


class PilotsOperations(_PilotsOperations):
    """Async pilots operations with a keyword-argument based interface.

    Each method converts its keyword arguments into the request body expected
    by the generated client before delegating to the parent implementation.
    """

    @distributed_trace_async
    async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
        """Search for pilots matching the given parameters/filters/sorts."""
        underlying = make_search_body(**kwargs)
        return await super().search(**underlying)

    @distributed_trace_async
    async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
        """Get a grouped summary of pilots."""
        underlying = make_summary_body(**kwargs)
        return await super().summary(**underlying)

    @distributed_trace_async
    async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
        """Register new pilots identified by their stamps."""
        underlying = make_add_pilot_stamps_body(**kwargs)
        return await super().add_pilot_stamps(**underlying)

    @distributed_trace_async
    async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None:
        """Update fields of existing pilots."""
        underlying = make_update_pilot_fields_body(**kwargs)
        return await super().update_pilot_fields(**underlying)
class SearchExtra(ResponseExtra, total=False):
    # Pagination options forwarded as query parameters, not body fields.
    page: int
    per_page: int


class SearchKwargs(SearchBody, SearchExtra): ...


class UnderlyingSearchArgs(ResponseExtra, total=False):
    # FIXME: The autorest-generated client has a bug: it expects IO[bytes]
    # even though the code was generated to support IO[bytes] | bytes.
    body: IO[bytes]


def _split_body_kwargs(
    kwargs: dict[str, Any], body_keys: frozenset[str]
) -> dict[str, Any]:
    """Split keyword arguments into a JSON request body plus passthrough options.

    Every key in *body_keys* is popped from *kwargs*; values that are not
    ``None`` are collected into a JSON document which is passed as the raw
    ``body`` stream. Whatever remains in *kwargs* (headers, params, cls,
    content_type, pagination, ...) is forwarded untouched.
    """
    body: dict[str, Any] = {}
    for key in body_keys:
        if key in kwargs:
            value = kwargs.pop(key)
            if value is not None:
                body[key] = value
    result: dict[str, Any] = {"body": BytesIO(json.dumps(body).encode("utf-8"))}
    result.update(kwargs)
    return result


def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs:
    """Build the argument dict for the generated ``search`` operation."""
    return cast(
        UnderlyingSearchArgs,
        _split_body_kwargs(cast("dict[str, Any]", kwargs), SearchBody.__optional_keys__),
    )


# ------------------ Summary ------------------


class SummaryBody(TypedDict, total=False):
    grouping: list[str]
    search: list[str]


class SummaryKwargs(SummaryBody, ResponseExtra): ...


class UnderlyingSummaryArgs(ResponseExtra, total=False):
    # FIXME: The autorest-generated client has a bug: it expects IO[bytes]
    # even though the code was generated to support IO[bytes] | bytes.
    body: IO[bytes]


def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs:
    """Build the argument dict for the generated ``summary`` operation."""
    return cast(
        UnderlyingSummaryArgs,
        _split_body_kwargs(cast("dict[str, Any]", kwargs), SummaryBody.__optional_keys__),
    )


# ------------------ AddPilotStamps ------------------


class AddPilotStampsBody(TypedDict, total=False):
    pilot_stamps: list[str]
    grid_type: str
    grid_site: str
    pilot_references: dict[str, str]
    pilot_status: PilotStatus
    vo: str


class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ...


class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False):
    # FIXME: The autorest-generated client has a bug: it expects IO[bytes]
    # even though the code was generated to support IO[bytes] | bytes.
    body: IO[bytes]


def make_add_pilot_stamps_body(
    **kwargs: Unpack[AddPilotStampsKwargs],
) -> UnderlyingAddPilotStampsArgs:
    """Build the argument dict for the generated ``add_pilot_stamps`` operation."""
    return cast(
        UnderlyingAddPilotStampsArgs,
        _split_body_kwargs(cast("dict[str, Any]", kwargs), AddPilotStampsBody.__optional_keys__),
    )


# ------------------ UpdatePilotFields ------------------


class UpdatePilotFieldsBody(TypedDict, total=False):
    pilot_stamps_to_fields_mapping: list[PilotFieldsMapping]


class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ...


class UnderlyingUpdatePilotFields(ResponseExtra, total=False):
    # FIXME: The autorest-generated client has a bug: it expects IO[bytes]
    # even though the code was generated to support IO[bytes] | bytes.
    body: IO[bytes]


def make_update_pilot_fields_body(
    **kwargs: Unpack[UpdatePilotFieldsKwargs],
) -> UnderlyingUpdatePilotFields:
    """Build the argument dict for the generated ``update_pilot_fields`` operation.

    NOTE(review): the body is serialized with ``json.dumps``, so the
    ``pilot_stamps_to_fields_mapping`` entries must already be
    JSON-serializable (plain dicts rather than pydantic models) — confirm
    against callers.
    """
    return cast(
        UnderlyingUpdatePilotFields,
        _split_body_kwargs(cast("dict[str, Any]", kwargs), UpdatePilotFieldsBody.__optional_keys__),
    )
# We're intentionally ignoring overrides here because we want to change the interface.
# mypy: disable-error-code=override


class PilotsOperations(_PilotsOperations):
    """Pilots operations with a keyword-argument based interface.

    Each method converts its keyword arguments into the request body expected
    by the generated client before delegating to the parent implementation.
    """

    @distributed_trace
    def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
        """Search for pilots matching the given parameters/filters/sorts."""
        underlying = make_search_body(**kwargs)
        return super().search(**underlying)

    @distributed_trace
    def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
        """Get a grouped summary of pilots."""
        underlying = make_summary_body(**kwargs)
        return super().summary(**underlying)

    @distributed_trace
    def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
        """Register new pilots identified by their stamps."""
        underlying = make_add_pilot_stamps_body(**kwargs)
        return super().add_pilot_stamps(**underlying)

    @distributed_trace
    def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None:
        """Update fields of existing pilots."""
        underlying = make_update_pilot_fields_body(**kwargs)
        return super().update_pilot_fields(**underlying)
@@ -49,19 +50,19 @@ class InvalidQueryError(DiracError): class TokenNotFoundError(DiracError): - def __init__(self, jti: str, detail: str | None = None): + def __init__(self, jti: str, detail: str = ""): self.jti: str = jti super().__init__(f"Token {jti} not found" + (f" ({detail})" if detail else "")) class JobNotFoundError(DiracError): - def __init__(self, job_id: int, detail: str | None = None): + def __init__(self, job_id: int, detail: str = ""): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else "")) class SandboxNotFoundError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -71,7 +72,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyAssignedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -81,7 +82,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyInsertedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -91,7 +92,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class JobError(DiracError): - def __init__(self, job_id, detail: str | None = None): + def __init__(self, job_id, detail: str = ""): self.job_id: int = job_id super().__init__( f"Error concerning job {job_id}" + (f" ({detail})" if detail else "") @@ -100,3 +101,43 @@ def __init__(self, job_id, detail: str | None = None): class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet 
available.""" + + +class DiracFormattedError(DiracError): + # TODO: Refactor? + pattern = "Error %s" + + def __init__(self, data: dict[str, str], detail: str = ""): + self.data = data + + parts = [f"({key}: {value})" for key, value in data.items()] + message = type(self).pattern % (" ".join(parts)) + if detail: + message += f": {detail}" + + super().__init__(message) + + +class PilotNotFoundError(DiracFormattedError): + pattern = "Pilot %s not found" + + def __init__( + self, + data: dict[str, str], + detail: str = "", + non_existing_pilots: set = set(), + ): + super().__init__(data, detail) + self.non_existing_pilots = non_existing_pilots + + +class PilotAlreadyExistsError(DiracFormattedError): + pattern = "Pilot %s already exists" + + +class PilotJobsNotFoundError(DiracFormattedError): + pattern = "Pilots or Jobs %s not found" + + +class PilotAlreadyAssociatedWithJobError(DiracFormattedError): + pattern = "Pilot is already associated with a job %s " diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index 415e36295..fc32d30c6 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -7,7 +7,7 @@ from datetime import datetime from enum import StrEnum -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field from typing_extensions import TypedDict @@ -29,7 +29,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int + value: str | int | datetime class VectorSearchSpec(TypedDict): @@ -59,13 +59,13 @@ class InsertedJob(TypedDict): TimeStamp: datetime -class JobSummaryParams(BaseModel): +class SummaryParams(BaseModel): grouping: list[str] search: list[SearchSpec] = [] # TODO: Add more validation -class JobSearchParams(BaseModel): +class SearchParams(BaseModel): parameters: list[str] | None = None search: list[SearchSpec] = [] sort: list[SortSpec] = [] @@ -272,3 
class PilotStatus(StrEnum):
    """Lifecycle states of a pilot.

    Defined before ``PilotFieldsMapping`` so that pydantic can resolve the
    ``Status`` annotation without a forward reference.
    """

    #: The pilot has been generated and is transferred to a remote site:
    SUBMITTED = "Submitted"
    #: The pilot is waiting for a computing resource in a batch queue:
    WAITING = "Waiting"
    #: The pilot is running a payload on a worker node:
    RUNNING = "Running"
    #: The pilot finished its execution:
    DONE = "Done"
    #: The pilot execution failed:
    FAILED = "Failed"
    #: The pilot was deleted:
    DELETED = "Deleted"
    #: The pilot execution was aborted:
    ABORTED = "Aborted"
    #: Cannot get information about the pilot status:
    UNKNOWN = "Unknown"


class PilotFieldsMapping(BaseModel, extra="forbid"):
    """All the fields that a user can modify on a Pilot (except PilotStamp)."""

    # PilotStamp identifies which pilot the remaining optional fields apply to.
    PilotStamp: str
    StatusReason: str | None = None
    Status: PilotStatus | None = None
    BenchMark: float | None = None
    DestinationSite: str | None = None
    Queue: str | None = None
    GridSite: str | None = None
    GridType: str | None = None
    AccountingSent: bool | None = None
    CurrentJobID: int | None = None


class LogLine(BaseModel):
    """A single structured pilot log record."""

    timestamp: str
    severity: str
    message: str
    scope: str
("JobParametersDB", "PilotLogsDB") from .job_parameters import JobParametersDB +from .pilot_logs import PilotLogsDB diff --git a/diracx-db/src/diracx/db/os/pilot_logs.py b/diracx-db/src/diracx/db/os/pilot_logs.py new file mode 100644 index 000000000..614c3cb50 --- /dev/null +++ b/diracx-db/src/diracx/db/os/pilot_logs.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from diracx.db.os.utils import BaseOSDB + + +class PilotLogsDB(BaseOSDB): + fields = { + "PilotStamp": {"type": "keyword"}, + "PilotID": {"type": "long"}, + "Severity": {"type": "keyword"}, + "Message": {"type": "text"}, + "VO": {"type": "keyword"}, + "TimeStamp": {"type": "date_nanos"}, + "Scope": {"type": "keyword"}, + } + index_prefix = "pilot_logs" + + def index_name(self, vo: str, doc_id: int) -> str: + split = int(int(doc_id) // 1e6) + # We split docs into smaller one (grouped by 1 million pilot) + # Ex: pilot_logs_dteam_1030m + return f"{self.index_prefix}_{vo.lower()}_{split}m" diff --git a/diracx-db/src/diracx/db/os/utils.py b/diracx-db/src/diracx/db/os/utils.py index ea5d292e6..7beb0f104 100644 --- a/diracx-db/src/diracx/db/os/utils.py +++ b/diracx-db/src/diracx/db/os/utils.py @@ -1,5 +1,7 @@ from __future__ import annotations +from opensearchpy.helpers import async_bulk + __all__ = ("BaseOSDB",) import contextlib @@ -197,13 +199,35 @@ async def upsert(self, vo: str, doc_id: int, document: Any) -> None: response, ) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + """Bulk inserting to database.""" + n_inserted, failed = await async_bulk( + self.client, actions=[doc | {"_index": index_name} for doc in docs] + ) + logger.info("Inserted %d documents to %s", n_inserted, index_name) + + if failed: + logger.error("Fail to insert %d documents to %s", failed, index_name) + async def search( - self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None - ) -> list[dict[str, Any]]: + self, + parameters, + search, + sorts, + *, + per_page: 
int = 10000, + page: int | None = None, + ) -> tuple[int, list[dict[str, Any]]]: """Search the database for matching results. See the DiracX search API documentation for details. """ + if page: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + body = {} if parameters: body["_source"] = parameters @@ -213,7 +237,12 @@ async def search( for sort in sorts: field_name = sort["parameter"] field_type = self.fields.get(field_name, {}).get("type") - require_type("sort", field_name, field_type, {"keyword", "long", "date"}) + require_type( + "sort", + field_name, + field_type, + {"keyword", "long", "date", "date_nanos"}, + ) body["sort"].append({field_name: {"order": sort["direction"]}}) params = {} @@ -226,17 +255,19 @@ async def search( ) hits = [hit["_source"] for hit in response["hits"]["hits"]] + total_hits = response["hits"]["total"]["value"] + # Dates are returned as strings, convert them to Python datetimes for hit in hits: for field_name in hit: if field_name not in self.fields: continue - if self.fields[field_name]["type"] == "date": + if self.fields[field_name]["type"] in ["date", "date_nanos"]: hit[field_name] = datetime.strptime( hit[field_name], "%Y-%m-%dT%H:%M:%S.%f%z" ) - return hits + return total_hits, hits def require_type(operator, field_name, field_type, allowed_types): diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index 3be3af8a3..e2f141ad5 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -12,6 +12,6 @@ from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB -from .pilot_agents.db import PilotAgentsDB +from .pilots.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py 
b/diracx-db/src/diracx/db/sql/dummy/db.py index 0c25df43c..b68bbe64d 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -1,9 +1,10 @@ from __future__ import annotations -from sqlalchemy import func, insert, select +from sqlalchemy import insert from uuid_utils import UUID -from diracx.db.sql.utils import BaseSQLDB, apply_search_filters +from diracx.core.models import SearchSpec +from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase from .schema import Cars, Owners @@ -21,19 +22,11 @@ class DummyDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = DummyDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: - columns = [Cars.__table__.columns[x] for x in group_by] - - stmt = select(*columns, func.count(Cars.license_plate).label("count")) - stmt = apply_search_filters(Cars.__table__.columns.__getitem__, stmt, search) - stmt = stmt.group_by(*columns) - - # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + async def dummy_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self.summary(Cars, group_by=group_by, search=search) async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index 89f2bb49d..8e563fb0e 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -5,7 +5,7 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Iterable -from sqlalchemy import bindparam, case, delete, func, insert, select, update +from sqlalchemy import bindparam, case, delete, insert, select, update if TYPE_CHECKING: from sqlalchemy.sql.elements import BindParameter 
@@ -13,8 +13,7 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints -from ..utils.functions import utcnow +from ..utils import BaseSQLDB, _get_columns, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, @@ -25,17 +24,6 @@ ) -def _get_columns(table, parameters): - columns = [x for x in table.columns] - if parameters: - if unrecognised_parameters := set(parameters) - set(table.columns.keys()): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) - columns = [c for c in columns if c.name in parameters] - return columns - - class JobDB(BaseSQLDB): metadata = JobDBBase.metadata @@ -54,22 +42,13 @@ class JobDB(BaseSQLDB): # to find a way to make it dynamic jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"] - async def summary(self, group_by, search) -> list[dict[str, str | int]]: + async def job_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: """Get a summary of the jobs.""" - columns = _get_columns(Jobs.__table__, group_by) - - stmt = select(*columns, func.count(Jobs.job_id).label("count")) - stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) - stmt = stmt.group_by(*columns) - - # Execute the query - return [ - dict(row._mapping) - async for row in (await self.conn.stream(stmt)) - if row.count > 0 # type: ignore - ] + return await self.summary(Jobs, group_by=group_by, search=search) - async def search( + async def search_jobs( self, parameters: list[str] | None, search: list[SearchSpec], @@ -80,34 +59,15 @@ async def search( page: int | None = None, ) -> tuple[int, list[dict[Any, Any]]]: """Search for jobs in the database.""" - # Find which columns to select - columns = _get_columns(Jobs.__table__, parameters) - - stmt = select(*columns) - - stmt = 
apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) - stmt = apply_sort_constraints(Jobs.__table__.columns.__getitem__, stmt, sorts) - - if distinct: - stmt = stmt.distinct() - - # Calculate total count before applying pagination - total_count_subquery = stmt.alias() - total_count_stmt = select(func.count()).select_from(total_count_subquery) - total = (await self.conn.execute(total_count_stmt)).scalar_one() - - # Apply pagination - if page is not None: - if page < 1: - raise InvalidQueryError("Page must be a positive integer") - if per_page < 1: - raise InvalidQueryError("Per page must be a positive integer") - stmt = stmt.offset((page - 1) * per_page).limit(per_page) - - # Execute the query - return total, [ - dict(row._mapping) async for row in (await self.conn.stream(stmt)) - ] + return await self.search( + model=Jobs, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) async def create_job(self, compressed_original_jdl: str): """Used to insert a new job with original JDL. 
class PilotAgentsDB(BaseSQLDB):
    """PilotAgentsDB class is a front-end to the PilotAgents Database."""

    metadata = PilotAgentsDBBase.metadata

    # ----------------------------- Insert Functions -----------------------------

    async def add_pilots(
        self,
        pilot_stamps: list[str],
        vo: str,
        grid_type: str = "DIRAC",
        grid_site: str = "Unknown",
        destination_site: str = "NotAssigned",
        pilot_references: dict[str, str] | None = None,
        status: str = PilotStatus.SUBMITTED,
    ):
        """Bulk add pilots in the DB.

        If we can't find a pilot_reference associated with a stamp, we take the stamp by default.
        """
        references = pilot_references if pilot_references is not None else {}
        now = datetime.now(tz=timezone.utc)

        rows = [
            {
                "PilotJobReference": references.get(stamp, stamp),
                "VO": vo,
                "GridType": grid_type,
                "GridSite": grid_site,
                "DestinationSite": destination_site,
                "SubmissionTime": now,
                "LastUpdateTime": now,
                "Status": status,
                "PilotStamp": stamp,
            }
            for stamp in pilot_stamps
        ]

        # Insert all pilots with a single bulk INSERT.
        await self.conn.execute(insert(PilotAgents).values(rows))

    async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]):
        """Associate a pilot with jobs.

        job_to_pilot_mapping format:
        ```py
        job_to_pilot_mapping = [
            {"PilotID": pilot_id, "JobID": job_id, "StartTime": now}
        ]
        ```

        Raises:
        - PilotNotFoundError if a pilot_id is not associated with a pilot.
        - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs.
        - NotImplementedError if the integrity error is not caught.

        **Important note**: We assume that a job exists.

        """
        stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping)

        try:
            await self.conn.execute(stmt)
        except IntegrityError as e:
            reason = str(e.orig).lower()

            # A missing pilot surfaces as a foreign-key violation.
            if "foreign key" in reason:
                raise PilotNotFoundError(
                    data={"pilot_stamps": str(job_to_pilot_mapping)},
                    detail="at least one of these pilots do not exist",
                ) from e

            # Re-inserting an existing (pilot, job) pair violates the unique key.
            if "duplicate entry" in reason or "unique constraint" in reason:
                raise PilotAlreadyAssociatedWithJobError(
                    data={"job_to_pilot_mapping": str(job_to_pilot_mapping)}
                ) from e

            # Engine-specific message we did not anticipate.
            raise NotImplementedError(
                "Engine Specific error not caught" + str(e)
            ) from e

    # ----------------------------- Delete Functions -----------------------------

    async def delete_pilots(self, pilot_ids: list[int]):
        """Destructive function. Delete pilots."""
        stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids))

        await self.conn.execute(stmt)

    async def remove_jobs_from_pilots(self, pilot_ids: list[int]):
        """Destructive function. De-associate jobs and pilots."""
        stmt = delete(JobToPilotMapping).where(
            JobToPilotMapping.pilot_id.in_(pilot_ids)
        )

        await self.conn.execute(stmt)

    async def delete_pilot_logs(self, pilot_ids: list[int]):
        """Destructive function. Remove logs from pilots."""
        stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids))

        await self.conn.execute(stmt)

    # ----------------------------- Update Functions -----------------------------

    async def update_pilot_fields(
        self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping]
    ):
        """Bulk update pilots with a mapping.

        pilot_stamps_to_fields_mapping format:
        ```py
        [
            {
                "PilotStamp": pilot_stamp,
                "BenchMark": bench_mark,
                "StatusReason": pilot_reason,
                "AccountingSent": accounting_sent,
                "Status": status,
                "CurrentJobID": current_job_id,
                "Queue": queue,
                ...
            }
        ]
        ```

        The mapping helps to update multiple fields at a time.

        Raises PilotNotFoundError if one of the pilots is not found.
        """
        # Mappings may set different subsets of fields. Building the SET
        # clause from the first mapping only (as the previous implementation
        # did) silently drops fields present in later mappings and fails on
        # missing bind parameters, so group by the set of fields and run one
        # executemany per group. This also makes an empty input a no-op
        # instead of an IndexError.
        groups: dict[frozenset[str], list[PilotFieldsMapping]] = {}
        for mapping in pilot_stamps_to_fields_mapping:
            fields = frozenset(mapping.model_dump(exclude_none=True)) - {"PilotStamp"}
            groups.setdefault(fields, []).append(mapping)

        updated = 0
        for fields, mappings in groups.items():
            stmt = (
                update(PilotAgents)
                .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp"))
                .values({key: bindparam(key) for key in fields})
            )

            values = [
                {
                    "b_pilot_stamp": mapping.PilotStamp,
                    **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True),
                }
                for mapping in mappings
            ]

            res = await self.conn.execute(stmt, values)
            updated += res.rowcount

        if updated != len(pilot_stamps_to_fields_mapping):
            raise PilotNotFoundError(
                data={"mapping": str(pilot_stamps_to_fields_mapping)}
            )

    # ----------------------------- Search Functions -----------------------------

    async def search_pilots(
        self,
        parameters: list[str] | None,
        search: list[SearchSpec],
        sorts: list[SortSpec],
        *,
        distinct: bool = False,
        per_page: int = 100,
        page: int | None = None,
    ) -> tuple[int, list[dict[Any, Any]]]:
        """Search for pilot information in the database."""
        return await self.search(
            model=PilotAgents,
            parameters=parameters,
            search=search,
            sorts=sorts,
            distinct=distinct,
            per_page=per_page,
            page=page,
        )

    async def search_pilot_to_job_mapping(
        self,
        parameters: list[str] | None,
        search: list[SearchSpec],
        sorts: list[SortSpec],
        *,
        distinct: bool = False,
        per_page: int = 100,
        page: int | None = None,
    ) -> tuple[int, list[dict[Any, Any]]]:
        """Search for jobs that are associated with pilots."""
        return await self.search(
            model=JobToPilotMapping,
            parameters=parameters,
            search=search,
            sorts=sorts,
            distinct=distinct,
            per_page=per_page,
            page=page,
        )

    async def pilot_summary(
        self, group_by: list[str], search: list[SearchSpec]
    ) -> list[dict[str, str | int]]:
        """Get a summary of the pilots."""
        return await self.summary(PilotAgents, group_by=group_by, search=search)
.functions import hash, substract_date, utcnow +from .functions import ( + hash, + substract_date, + utcnow, +) from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn __all__ = ( - "utcnow", + "_get_columns", + "apply_search_filters", + "apply_sort_constraints", + "BaseSQLDB", "Column", - "NullColumn", "DateNowColumn", - "BaseSQLDB", "EnumBackedBool", "EnumColumn", - "apply_search_filters", - "apply_sort_constraints", - "substract_date", "hash", + "NullColumn", + "substract_date", "SQLDBUnavailableError", + "utcnow", ) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py index 6286364af..cdf753208 100644 --- a/diracx-db/src/diracx/db/sql/utils/base.py +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -8,16 +8,16 @@ from collections.abc import AsyncIterator from contextvars import ContextVar from datetime import datetime -from typing import Self, cast +from typing import Any, Self, cast from pydantic import TypeAdapter -from sqlalchemy import DateTime, MetaData, select +from sqlalchemy import DateTime, MetaData, func, select from sqlalchemy.exc import OperationalError from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from diracx.core.exceptions import InvalidQueryError from diracx.core.extensions import select_from_extension -from diracx.core.models import SortDirection +from diracx.core.models import SearchSpec, SortDirection, SortSpec from diracx.core.settings import SqlalchemyDsn from diracx.db.exceptions import DBUnavailableError @@ -227,6 +227,71 @@ async def ping(self): except OperationalError as e: raise SQLDBUnavailableError("Cannot ping the DB") from e + async def search( + self, + model: Any, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search in a SQL database, with filters.""" + # 
Find which columns to select + columns = _get_columns(model.__table__, parameters) + + stmt = select(*columns) + + stmt = apply_search_filters(model.__table__.columns.__getitem__, stmt, search) + stmt = apply_sort_constraints(model.__table__.columns.__getitem__, stmt, sorts) + + if distinct: + stmt = stmt.distinct() + + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self.conn.execute(total_count_stmt)).scalar_one() + + # Apply pagination + if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") + stmt = stmt.offset((page - 1) * per_page).limit(per_page) + + # Execute the query + return total, [ + dict(row._mapping) async for row in (await self.conn.stream(stmt)) + ] + + async def summary( + self, model: Any, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of a table.""" + columns = _get_columns(model.__table__, group_by) + + pk_columns = list(model.__table__.primary_key.columns) + if not pk_columns: + raise ValueError( + "Model has no primary key and no count_column was provided." 
+ ) + count_col = pk_columns[0] + + stmt = select(*columns, func.count(count_col).label("count")) + stmt = apply_search_filters(model.__table__.columns.__getitem__, stmt, search) + stmt = stmt.group_by(*columns) + + # Execute the query + return [ + dict(row._mapping) + async for row in (await self.conn.stream(stmt)) + if row.count > 0 # type: ignore + ] + def find_time_resolution(value): if isinstance(value, datetime): @@ -258,6 +323,17 @@ def find_time_resolution(value): raise InvalidQueryError(f"Cannot parse {value=}") +def _get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + def apply_search_filters(column_mapping, stmt, search): for query in search: try: diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py index 34cb2a0da..536412406 100644 --- a/diracx-db/src/diracx/db/sql/utils/functions.py +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -2,16 +2,30 @@ import hashlib from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Sequence, Type -from sqlalchemy import DateTime, func +from sqlalchemy import DateTime, RowMapping, asc, desc, func, select +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import expression +from sqlalchemy.sql import ColumnElement, expression + +from diracx.core.exceptions import DiracFormattedError, InvalidQueryError if TYPE_CHECKING: from sqlalchemy.types import TypeEngine +def _get_columns(table, parameters): + columns = [x for x in table.columns] + if parameters: + if unrecognised_parameters := set(parameters) - set(table.columns.keys()): + raise 
InvalidQueryError( + f"Unrecognised parameters requested {unrecognised_parameters}" + ) + columns = [c for c in columns if c.name in parameters] + return columns + + class utcnow(expression.FunctionElement): # noqa: N801 type: TypeEngine = DateTime() inherit_cache: bool = True @@ -140,3 +154,73 @@ def substract_date(**kwargs: float) -> datetime: def hash(code: str): return hashlib.sha256(code.encode()).hexdigest() + + +def raw_hash(code: str): + return hashlib.sha256(code.encode()).digest() + + +async def fetch_records_bulk_or_raises( + conn: AsyncConnection, + model: Any, # Here, we currently must use `Any` because `declarative_base()` returns any + missing_elements_error_cls: Type[DiracFormattedError], + column_attribute_name: str, + column_name: str, + elements_to_fetch: list, + order_by: tuple[str, str] | None = None, + allow_more_than_one_result_per_input: bool = False, + allow_no_result: bool = False, +) -> Sequence[RowMapping]: + """Fetches a list of elements in a table, returns a list of elements. + All elements from the `element_to_fetch` **must** be present. + Raises the specified error if at least one is missing. 
+ + Example: + fetch_records_bulk_or_raises( + self.conn, + PilotAgents, + PilotNotFound, + "pilot_id", + "PilotID", + [1,2,3] + ) + + """ + assert elements_to_fetch + + # Get the column that needs to be in elements_to_fetch + column = getattr(model, column_attribute_name) + + # Create the request + stmt = select(model).with_for_update().where(column.in_(elements_to_fetch)) + + if order_by: + column_name_to_order_by, direction = order_by + column_to_order_by = getattr(model, column_name_to_order_by) + + operator: ColumnElement = ( + asc(column_to_order_by) if direction == "asc" else desc(column_to_order_by) + ) + + stmt = stmt.order_by(operator) + + # Transform into dictionaries + raw_results = await conn.execute(stmt) + results = raw_results.mappings().all() + + # Detects duplicates + if not allow_more_than_one_result_per_input: + if len(results) > len(elements_to_fetch): + raise RuntimeError("Seems to have duplicates in the database.") + + if not allow_no_result: + # Checks if we have every elements we wanted + found_keys = {row[column_name] for row in results} + missing = set(elements_to_fetch) - found_keys + + if missing: + raise missing_elements_error_cls( + data={column_name: str(missing)}, detail=str(missing) + ) + + return results diff --git a/diracx-db/tests/jobs/test_job_db.py b/diracx-db/tests/jobs/test_job_db.py index e6ca58ce9..5ae49ad10 100644 --- a/diracx-db/tests/jobs/test_job_db.py +++ b/diracx-db/tests/jobs/test_job_db.py @@ -51,34 +51,34 @@ async def test_search_parameters(populated_job_db): """Test that we can search specific parameters for jobs in the database.""" async with populated_job_db as job_db: # Search a specific parameter: JobID - total, result = await job_db.search(["JobID"], [], []) + total, result = await job_db.search_jobs(["JobID"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"JobID"} # Search a specific parameter: Status - total, result = await job_db.search(["Status"], [], []) + total, result 
= await job_db.search_jobs(["Status"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"Status"} # Search for multiple parameters: JobID, Status - total, result = await job_db.search(["JobID", "Status"], [], []) + total, result = await job_db.search_jobs(["JobID", "Status"], [], []) assert total == 100 assert result for r in result: assert r.keys() == {"JobID", "Status"} # Search for a specific parameter but use distinct: Status - total, result = await job_db.search(["Status"], [], [], distinct=True) + total, result = await job_db.search_jobs(["Status"], [], [], distinct=True) assert total == 1 assert result # Search for a non-existent parameter: Dummy with pytest.raises(InvalidQueryError): - total, result = await job_db.search(["Dummy"], [], []) + total, result = await job_db.search_jobs(["Dummy"], [], []) async def test_search_conditions(populated_job_db): @@ -88,7 +88,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 1 assert result assert len(result) == 1 @@ -98,7 +98,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.LESS_THAN, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 2 assert result assert len(result) == 2 @@ -109,7 +109,7 @@ async def test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 99 assert result assert len(result) == 99 @@ -119,14 +119,14 @@ async def 
test_search_conditions(populated_job_db): condition = ScalarSearchSpec( parameter="JobID", operator=ScalarSearchOperator.EQUAL, value=5873 ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert not result # Search a specific vector condition: JobID in 1,2,3 condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 3] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 3 assert result assert len(result) == 3 @@ -136,7 +136,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 2 assert result assert len(result) == 2 @@ -146,7 +146,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 97 assert result assert len(result) == 97 @@ -156,7 +156,7 @@ async def test_search_conditions(populated_job_db): condition = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 5873] ) - total, result = await job_db.search([], [condition], []) + total, result = await job_db.search_jobs([], [condition], []) assert total == 98 assert result assert len(result) == 98 @@ -169,7 +169,7 @@ async def test_search_conditions(populated_job_db): condition2 = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await job_db.search([], [condition1, condition2], []) + total, result = await job_db.search_jobs([], [condition1, 
condition2], []) assert total == 1 assert result assert len(result) == 1 @@ -183,7 +183,7 @@ async def test_search_conditions(populated_job_db): condition2 = VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=[4, 5, 6] ) - total, result = await job_db.search([], [condition1, condition2], []) + total, result = await job_db.search_jobs([], [condition1, condition2], []) assert total == 0 assert not result @@ -193,7 +193,7 @@ async def test_search_sorts(populated_job_db): async with populated_job_db as job_db: # Search and sort by JobID in ascending order sort = SortSpec(parameter="JobID", direction=SortDirection.ASC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result for i, r in enumerate(result): @@ -201,7 +201,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by JobID in descending order sort = SortSpec(parameter="JobID", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result for i, r in enumerate(result): @@ -209,7 +209,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by Owner in ascending order sort = SortSpec(parameter="Owner", direction=SortDirection.ASC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result # Assert that owner10 is before owner2 because of the lexicographical order @@ -218,7 +218,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by Owner in descending order sort = SortSpec(parameter="Owner", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort]) + total, result = await job_db.search_jobs([], [], [sort]) assert total == 100 assert result # Assert that owner10 is before owner2 because of the lexicographical order @@ -228,7 
+228,7 @@ async def test_search_sorts(populated_job_db): # Search and sort by OwnerGroup in ascending order and JobID in descending order sort1 = SortSpec(parameter="OwnerGroup", direction=SortDirection.ASC) sort2 = SortSpec(parameter="JobID", direction=SortDirection.DESC) - total, result = await job_db.search([], [], [sort1, sort2]) + total, result = await job_db.search_jobs([], [], [sort1, sort2]) assert total == 100 assert result assert result[0]["OwnerGroup"] == "owner_group1" @@ -241,45 +241,45 @@ async def test_search_pagination(populated_job_db): """Test that we can search for jobs in the database.""" async with populated_job_db as job_db: # Search for the first 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=1) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=1) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 1 # Search for the second 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=2) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=2) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 11 # Search for the last 10 jobs - total, result = await job_db.search([], [], [], per_page=10, page=10) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=10) assert total == 100 assert result assert len(result) == 10 assert result[0]["JobID"] == 91 # Search for the second 50 jobs - total, result = await job_db.search([], [], [], per_page=50, page=2) + total, result = await job_db.search_jobs([], [], [], per_page=50, page=2) assert total == 100 assert result assert len(result) == 50 assert result[0]["JobID"] == 51 # Invalid page number - total, result = await job_db.search([], [], [], per_page=10, page=11) + total, result = await job_db.search_jobs([], [], [], per_page=10, page=11) assert total == 100 assert not result # Invalid page number with 
pytest.raises(InvalidQueryError): - result = await job_db.search([], [], [], per_page=10, page=0) + result = await job_db.search_jobs([], [], [], per_page=10, page=0) # Invalid per_page number with pytest.raises(InvalidQueryError): - result = await job_db.search([], [], [], per_page=0, page=1) + result = await job_db.search_jobs([], [], [], per_page=0, page=1) async def test_set_job_commands_invalid_job_id(job_db: JobDB): diff --git a/diracx-db/tests/opensearch/test_search.py b/diracx-db/tests/opensearch/test_search.py index 93998ac3e..8013edd9a 100644 --- a/diracx-db/tests/opensearch/test_search.py +++ b/diracx-db/tests/opensearch/test_search.py @@ -120,15 +120,15 @@ async def prefilled_db(request): async def test_specified_parameters(prefilled_db: DummyOSDB): - results = await prefilled_db.search(None, [], []) - assert len(results) == 3 + total, results = await prefilled_db.search(None, [], []) + assert total == 3 assert DOC1 in results and DOC2 in results and DOC3 in results - results = await prefilled_db.search([], [], []) - assert len(results) == 3 + total, results = await prefilled_db.search([], [], []) + assert total == 3 assert DOC1 in results and DOC2 in results and DOC3 in results - results = await prefilled_db.search(["IntField"], [], []) + total, results = await prefilled_db.search(["IntField"], [], []) expected_results = [] for doc in [DOC1, DOC2, DOC3]: expected_doc = {key: doc[key] for key in {"IntField"}} @@ -136,58 +136,67 @@ async def test_specified_parameters(prefilled_db: DummyOSDB): # If it is the all() check below no longer makes sense assert expected_doc not in expected_results expected_results.append(expected_doc) - assert len(results) == len(expected_results) + assert total == len(expected_results) assert all(result in expected_results for result in results) - results = await prefilled_db.search(["IntField", "UnknownField"], [], []) + total, results = await prefilled_db.search(["IntField", "UnknownField"], [], []) expected_results = [ 
{"IntField": DOC1["IntField"], "UnknownField": DOC1["UnknownField"]}, {"IntField": DOC2["IntField"], "UnknownField": DOC2["UnknownField"]}, {"IntField": DOC3["IntField"]}, ] - assert len(results) == len(expected_results) + assert total == len(expected_results) assert all(result in expected_results for result in results) async def test_pagination_asc(prefilled_db: DummyOSDB): sort = [{"parameter": "IntField", "direction": "asc"}] - results = await prefilled_db.search(None, [], sort) + total, results = await prefilled_db.search(None, [], sort) assert results == [DOC3, DOC2, DOC1] + assert total == 3 # Pagination has no effect if a specific page isn't requested - results = await prefilled_db.search(None, [], sort, per_page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2) assert results == [DOC3, DOC2, DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=1) assert results == [DOC3, DOC2] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=2) assert results == [DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=2, page=3) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=3) assert results == [] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=1) assert results == [DOC3] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=2) assert results == [DOC2] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=3) + total, results = await prefilled_db.search(None, [], sort, 
per_page=1, page=3) assert results == [DOC1] + assert total == 3 - results = await prefilled_db.search(None, [], sort, per_page=1, page=4) + total, results = await prefilled_db.search(None, [], sort, per_page=1, page=4) assert results == [] + assert total == 3 async def test_pagination_desc(prefilled_db: DummyOSDB): sort = [{"parameter": "IntField", "direction": "desc"}] - results = await prefilled_db.search(None, [], sort, per_page=2, page=1) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=1) assert results == [DOC1, DOC2] - results = await prefilled_db.search(None, [], sort, per_page=2, page=2) + total, results = await prefilled_db.search(None, [], sort, per_page=2, page=2) assert results == [DOC3] @@ -195,22 +204,26 @@ async def test_eq_filter_long(prefilled_db: DummyOSDB): part = {"parameter": "IntField", "operator": "eq"} # Search for an ID which doesn't exist - results = await prefilled_db.search(None, [part | {"value": "78"}], []) + total, results = await prefilled_db.search(None, [part | {"value": "78"}], []) assert results == [] + assert total == 0 # Check the DB contains what we expect when not filtering - results = await prefilled_db.search(None, [], []) - assert len(results) == 3 + total, results = await prefilled_db.search(None, [], []) + assert total == 3 assert DOC1 in results assert DOC2 in results assert DOC3 in results # Search separately for the two documents which do exist - results = await prefilled_db.search(None, [part | {"value": "1234"}], []) + total, results = await prefilled_db.search(None, [part | {"value": "1234"}], []) assert results == [DOC1] - results = await prefilled_db.search(None, [part | {"value": "679"}], []) + assert total == 1 + total, results = await prefilled_db.search(None, [part | {"value": "679"}], []) assert results == [DOC2] - results = await prefilled_db.search(None, [part | {"value": "42"}], []) + assert total == 1 + total, results = await prefilled_db.search(None, [part | {"value": 
"42"}], []) + assert total == 1 assert results == [DOC3] @@ -218,80 +231,97 @@ async def test_operators_long(prefilled_db: DummyOSDB): part = {"parameter": "IntField"} query = part | {"operator": "neq", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "in", "values": ["1234", "42"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "not in", "values": ["1234", "42"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "lt", "value": "679"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "gt", "value": "1234"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "lt", "value": "42"} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await 
prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 async def test_operators_date(prefilled_db: DummyOSDB): part = {"parameter": "DateField"} query = part | {"operator": "eq", "value": DOC3["DateField"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "neq", "value": DOC2["DateField"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 doc1_time = DOC1["DateField"].strftime("%Y-%m-%dT%H:%M") doc2_time = DOC2["DateField"].strftime("%Y-%m-%dT%H:%M") doc3_time = DOC3["DateField"].strftime("%Y-%m-%dT%H:%M") query = part | {"operator": "in", "values": [doc1_time, doc2_time]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + assert total == 2 query = part | {"operator": "not in", "values": [doc1_time, doc2_time]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": doc1_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "lt", "value": doc3_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await 
prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "lt", "value": doc2_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "gt", "value": doc1_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | {"operator": "gt", "value": doc3_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"]} + assert total == 1 query = part | {"operator": "gt", "value": doc2_time} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 @pytest.mark.parametrize( @@ -312,24 +342,28 @@ async def test_operators_date_partial_doc1(prefilled_db: DummyOSDB, date_format: formatted_date = DOC1["DateField"].strftime(date_format) query = {"parameter": "DateField", "operator": "eq", "value": formatted_date} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"]} + assert total == 1 query = {"parameter": "DateField", "operator": "neq", "value": formatted_date} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"], DOC3["IntField"]} + 
assert total == 2 async def test_operators_keyword(prefilled_db: DummyOSDB): part = {"parameter": "KeywordField1"} query = part | {"operator": "eq", "value": DOC1["KeywordField1"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC2["IntField"]} + assert total == 2 query = part | {"operator": "neq", "value": DOC1["KeywordField1"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC3["IntField"]} + assert total == 1 part = {"parameter": "KeywordField0"} @@ -337,27 +371,31 @@ async def test_operators_keyword(prefilled_db: DummyOSDB): "operator": "in", "values": [DOC1["KeywordField0"], DOC3["KeywordField0"]], } - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC1["IntField"], DOC3["IntField"]} + assert total == 2 query = part | {"operator": "in", "values": ["missing"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == set() + assert total == 0 query = part | { "operator": "not in", "values": [DOC1["KeywordField0"], DOC3["KeywordField0"]], } - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == {DOC2["IntField"]} + assert total == 1 query = part | {"operator": "not in", "values": ["missing"]} - results = await prefilled_db.search(["IntField"], [query], []) + total, results = await prefilled_db.search(["IntField"], [query], []) assert {x["IntField"] for x in results} == { DOC1["IntField"], DOC2["IntField"], 
DOC3["IntField"], } + assert total == 3 # The MockOSDBMixin doesn't validate if types are indexed correctly if not isinstance(prefilled_db, MockOSDBMixin): @@ -387,36 +425,42 @@ async def test_unindexed_field(prefilled_db: DummyOSDB): async def test_sort_long(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "IntField", "direction": "asc"}] ) assert results == [DOC3, DOC2, DOC1] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "IntField", "direction": "desc"}] ) assert results == [DOC1, DOC2, DOC3] + assert total == 3 async def test_sort_date(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "DateField", "direction": "asc"}] ) assert results == [DOC2, DOC3, DOC1] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "DateField", "direction": "desc"}] ) assert results == [DOC1, DOC3, DOC2] + assert total == 3 async def test_sort_keyword(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [{"parameter": "KeywordField0", "direction": "asc"}] ) assert results == [DOC1, DOC3, DOC2] - results = await prefilled_db.search( + assert total == 3 + total, results = await prefilled_db.search( None, [], [{"parameter": "KeywordField0", "direction": "desc"}] ) assert results == [DOC2, DOC3, DOC1] + assert total == 3 async def test_sort_text(prefilled_db: DummyOSDB): @@ -436,7 +480,7 @@ async def test_sort_unknown(prefilled_db: DummyOSDB): async def test_sort_multiple(prefilled_db: DummyOSDB): - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -445,8 +489,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC2, 
DOC1, DOC3] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -455,8 +500,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC1, DOC2, DOC3] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -465,8 +511,9 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC3, DOC2, DOC1] + assert total == 3 - results = await prefilled_db.search( + total, results = await prefilled_db.search( None, [], [ @@ -475,3 +522,4 @@ async def test_sort_multiple(prefilled_db: DummyOSDB): ], ) assert results == [DOC3, DOC2, DOC1] + assert total == 3 diff --git a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py deleted file mode 100644 index 3ca989885..000000000 --- a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import pytest - -from diracx.db.sql.pilot_agents.db import PilotAgentsDB - - -@pytest.fixture -async def pilot_agents_db(tmp_path) -> PilotAgentsDB: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with agents_db.engine_context(): - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): - async with pilot_agents_db as pilot_agents_db: - # Add a pilot reference - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - stamp_dict = dict(zip(refs, stamps)) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict - ) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=None - ) diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilots/__init__.py similarity 
index 100% rename from diracx-db/tests/pilot_agents/__init__.py rename to diracx-db/tests/pilots/__init__.py diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py new file mode 100644 index 000000000..1e7397b39 --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .utils import ( + add_stamps, # noqa: F401 + create_old_pilots_environment, # noqa: F401 + create_timed_pilots, # noqa: F401 + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.mark.asyncio +async def test_insert_and_select(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Accept duplicates because it is checked by the logic + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None + ) + + +@pytest.mark.asyncio +async def test_insert_and_delete(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(2)] + stamps = [f"stamp_{i}" for i in range(2)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, 
grid_type="DIRAC", pilot_references=pilot_references + ) + + # Works, the pilot exists + res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + # We delete the first pilot + await pilot_db.delete_pilots([res[0]["PilotID"]]) + + # We get the 2nd pilot that is not deleted (no error) + await get_pilots_by_stamp(pilot_db, [stamps[1]]) + # We get the 1st pilot that is deleted (error) + + assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + +@pytest.mark.asyncio +async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Assert values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 0.0 + assert pilot["Status"] == PilotStatus.SUBMITTED + assert pilot["StatusReason"] == "Unknown" + assert not pilot["AccountingSent"] + + # + # Modify a pilot, then check if every change is done + # + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ) + ] + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Set values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 1.0 + assert pilot["Status"] == PilotStatus.WAITING + assert pilot["StatusReason"] == "NewReason" + assert pilot["AccountingSent"] + + +@pytest.mark.asyncio +async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): + """We will proceed in a few steps. + + 1.
Create a pilot + 2. Verify that he is not associated with any job + 3. Associate with jobs + 4. Verify that he is associated with this job + 5. Associate with jobs that he already has and two that he has not + 6. Associate with jobs that he has not, but were involved in a crash + """ + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + # Add pilot + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + pilot_id = pilot["PilotID"] + + # Verify that he has no jobs + assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 + + now = datetime.now(tz=timezone.utc) + + # Associate pilot with jobs + pilot_jobs = [1, 2, 3] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Verify that he has all jobs + db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) + # We test both length and if every job is included if for any reason we have duplicates + assert all(job in db_jobs for job in pilot_jobs) + assert len(pilot_jobs) == len(db_jobs) + + # Associate pilot with a job that he already has, and one that he has not + pilot_jobs = [10, 1, 5] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Associate pilot with jobs that he has not, but was previously in an error + # To test that the rollback worked + pilot_jobs = [5, 10] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for
job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py new file mode 100644 index 000000000..be80f0179 --- /dev/null +++ b/diracx-db/tests/pilots/test_query.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_db(pilot_db): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i + 1}" for i in range(N)] + stamps = [f"stamp_{i + 1}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ) + for i, pilot_stamp in enumerate(stamps) + ] + ) + + yield pilot_db + + +async def test_search_parameters(populated_pilot_db): + """Test that we can search specific parameters for pilots in the database.""" + async with populated_pilot_db as pilot_db: + # 
Search a specific parameter: PilotID + total, result = await pilot_db.search_pilots(["PilotID"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + + # Search a specific parameter: Status + total, result = await pilot_db.search_pilots(["Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"Status"} + + # Search for multiple parameters: PilotID, Status + total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + + # Search for a specific parameter but use distinct: Status + total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) + assert total == len(PILOT_STATUSES) + assert result + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + total, result = await pilot_db.search_pilots(["Dummy"], [], []) + + +async def test_search_conditions(populated_pilot_db): + """Test that we can search for specific pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + 
total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert not result + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + + # Search for multiple 
conditions based on different parameters: PilotStamp eq stamp_5, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 0 + assert not result + + +async def test_search_sorts(populated_pilot_db): + """Test that we can search for pilots in the database and sort the results.""" + async with populated_pilot_db as pilot_db: + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the
lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) + assert total == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + + +@pytest.mark.parametrize( + "per_page, page, expected_len, expected_first_id, expect_exception", + [ + (10, 1, 10, 1, None), # Page 1 + (10, 2, 10, 11, None), # Page 2 + (10, 10, 10, 91, None), # Page 10 + (50, 2, 50, 51, None), # Page 2 with 50 per page + (10, 11, 0, None, None), # Page beyond range, should return empty + (10, 0, None, None, InvalidQueryError), # Invalid page + (0, 1, None, None, InvalidQueryError), # Invalid per_page + ], +) +async def test_search_pagination( + populated_pilot_db, + per_page, + page, + expected_len, + expected_first_id, + expect_exception, +): + """Test pagination logic in pilot search.""" + async with populated_pilot_db as pilot_db: + if expect_exception: + with pytest.raises(expect_exception): + await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) + else: + total, result = await pilot_db.search_pilots( + [], [], [], per_page=per_page, page=page + ) + assert total == N + if expected_len == 0: + assert 
not result + else: + assert result + assert len(result) == expected_len + assert result[0]["PilotID"] == expected_first_id diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py new file mode 100644 index 000000000..793310d0d --- /dev/null +++ b/diracx-db/tests/pilots/utils.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from sqlalchemy import update + +from diracx.core.models import ( + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + +# ------------ Fetching data ------------ + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return pilots + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +# ------------ Creating data ------------ + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots( + stamps, vo, 
grid_type="DIRAC", pilot_references=pilot_references + ) + + return await get_pilots_by_stamp(db, stamps) + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await get_pilots_by_stamp(db, pilot_stamps) + return pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. 
Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index e0106d833..9c10d9be2 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -27,7 +27,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): # So it is important to write test this way async with dummy_db as dummy_db: # First we check that the DB is empty - result = await dummy_db.summary(["Model"], []) + result = await dummy_db.dummy_summary(["Model"], []) assert not result # Now we add some data in the DB @@ -44,13 +44,13 @@ async def test_insert_and_summary(dummy_db: DummyDB): # Check that there are now 10 cars assigned to a single driver async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with dummy_db as dummy_db: - result = await dummy_db.summary( + result = await dummy_db.dummy_summary( ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}] ) @@ -58,7 +58,7 @@ async def test_insert_and_summary(dummy_db: DummyDB): async with dummy_db as dummy_db: with pytest.raises(InvalidQueryError): - result = await dummy_db.summary( + result = await dummy_db.dummy_summary( ["OwnerID"], [ { @@ -93,7 +93,7 @@ async def test_successful_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], 
[]) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -104,7 +104,7 @@ async def test_successful_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -114,7 +114,7 @@ async def test_successful_transaction(dummy_db): # Start a new transaction # The previous data should still be there because the transaction was committed (successful) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 @@ -129,12 +129,12 @@ async def test_failed_transaction(dummy_db): # The connection is created when the context manager is entered # This is our transaction - with pytest.raises(KeyError): + with pytest.raises(InvalidQueryError): async with dummy_db as dummy_db: assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -149,7 +149,8 @@ async def test_failed_transaction(dummy_db): assert result # This will raise an exception and the transaction will be rolled back - result = await dummy_db.summary(["unexistingfieldraisinganerror"], []) + + result = await dummy_db.dummy_summary(["unexistingfieldraisinganerror"], []) assert result[0]["count"] == 10 # The connection is closed when the context manager is exited @@ -159,7 +160,7 @@ async def test_failed_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result @@ -203,7 +204,7 @@ async def 
test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -217,7 +218,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # This will raise an exception but the transaction will be rolled back @@ -231,7 +232,7 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should not be there because the transaction was rolled back (failed) async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Start a new transaction, this time we commit it manually @@ -240,7 +241,7 @@ async def test_successful_with_exception_transaction(dummy_db): assert dummy_db.conn # First we check that the DB is empty - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert not result # Add data @@ -254,7 +255,7 @@ async def test_successful_with_exception_transaction(dummy_db): ) assert result - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Manually commit the transaction, and then raise an exception @@ -271,5 +272,5 @@ async def test_successful_with_exception_transaction(dummy_db): # Start a new transaction # The previous data should be there because the transaction was committed before the exception async with dummy_db as dummy_db: - result = await dummy_db.summary(["OwnerID"], []) + result = await dummy_db.dummy_summary(["OwnerID"], []) assert result[0]["count"] == 10 diff --git a/diracx-logic/src/diracx/logic/jobs/query.py 
b/diracx-logic/src/diracx/logic/jobs/query.py index efb4b2fc5..23fb4557e 100644 --- a/diracx-logic/src/diracx/logic/jobs/query.py +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -5,9 +5,9 @@ from diracx.core.config.schema import Config from diracx.core.models import ( - JobSearchParams, - JobSummaryParams, ScalarSearchOperator, + SearchParams, + SummaryParams, ) from diracx.db.os.job_parameters import JobParametersDB from diracx.db.sql.job.db import JobDB @@ -27,7 +27,7 @@ async def search( preferred_username: str | None, page: int = 1, per_page: int = 100, - body: JobSearchParams | None = None, + body: SearchParams | None = None, ) -> tuple[int, list[dict[str, Any]]]: """Retrieve information about jobs.""" # Apply a limit to per_page to prevent abuse of the API @@ -35,7 +35,7 @@ async def search( per_page = MAX_PER_PAGE if body is None: - body = JobSearchParams() + body = SearchParams() if query_logging_info := ("LoggingInfo" in (body.parameters or [])): if body.parameters: @@ -62,7 +62,7 @@ async def search( } ) - total, jobs = await job_db.search( + total, jobs = await job_db.search_jobs( body.parameters, body.search, body.sort, @@ -85,7 +85,7 @@ async def summary( config: Config, job_db: JobDB, preferred_username: str, - body: JobSummaryParams, + body: SummaryParams, ): """Show information suitable for plotting.""" if not config.Operations["Defaults"].Services.JobMonitoring.GlobalJobsInfo: @@ -100,4 +100,4 @@ async def summary( "value": preferred_username, } ) - return await job_db.summary(body.grouping, body.search) + return await job_db.job_summary(body.grouping, body.search) diff --git a/diracx-logic/src/diracx/logic/jobs/status.py b/diracx-logic/src/diracx/logic/jobs/status.py index 82b670137..2f1138cac 100644 --- a/diracx-logic/src/diracx/logic/jobs/status.py +++ b/diracx-logic/src/diracx/logic/jobs/status.py @@ -124,7 +124,7 @@ async def set_job_statuses( } # search all jobs at once - _, results = await job_db.search( + _, results = await 
job_db.search_jobs( parameters=["Status", "StartExecTime", "EndExecTime", "JobID", "VO"], search=[ { @@ -291,7 +291,7 @@ async def reschedule_jobs( attribute_changes: defaultdict[int, dict[str, str]] = defaultdict(dict) jdl_changes = {} - _, results = await job_db.search( + _, results = await job_db.search_jobs( parameters=[ "Status", "MinorStatus", @@ -558,7 +558,7 @@ async def add_heartbeat( "operator": VectorSearchOperator.IN, "values": list(data), } - _, results = await job_db.search( + _, results = await job_db.search_jobs( parameters=["Status", "JobID"], search=[search_query], sorts=[] ) if len(results) != len(data): @@ -623,9 +623,15 @@ async def _insert_parameters( if not updates: return # Get the VOs for the job IDs (required for the index template) - job_vos = await job_db.summary( + job_vos = await job_db.job_summary( ["JobID", "VO"], - [{"parameter": "JobID", "operator": "in", "values": list(updates)}], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=list(updates), + ) + ], ) job_id_to_vo = {int(x["JobID"]): str(x["VO"]) for x in job_vos} # Upsert the parameters into the JobParametersDB diff --git a/diracx-logic/src/diracx/logic/pilots/__init__.py b/diracx-logic/src/diracx/logic/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py new file mode 100644 index 000000000..417d3a9a6 --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.models import PilotFieldsMapping +from diracx.db.sql import PilotAgentsDB + +from .query import ( + get_outdated_pilots, + get_pilot_ids_by_stamps, + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + + +async def 
register_new_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + vo: str, + grid_type: str, + grid_site: str, + destination_site: str, + status: str, + pilot_job_references: dict[str, str] | None, +): + # [IMPORTANT] Check unicity of pilot stamps + # If a pilot already exists, we raise an error (transaction will rollback) + existing_pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) + + # If we found pilots from the list, this means some pilots already exists + if len(existing_pilots) > 0: + found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} + + raise PilotAlreadyExistsError(data={"pilot_stamps": str(found_keys)}) + + await pilot_db.add_pilots( + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_references=pilot_job_references, + status=status, + ) + + +async def delete_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str] | None = None, + age_in_days: int | None = None, + delete_only_aborted: bool = True, + vo_constraint: str | None = None, +): + if pilot_stamps: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True + ) + else: + assert age_in_days + assert vo_constraint + + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + pilots = await get_outdated_pilots( + pilot_db=pilot_db, + cutoff_date=cutoff_date, + only_aborted=delete_only_aborted, + parameters=["PilotID"], + vo_constraint=vo_constraint, + ) + + pilot_ids = [pilot["PilotID"] for pilot in pilots] + + await pilot_db.remove_jobs_from_pilots(pilot_ids) + await pilot_db.delete_pilot_logs(pilot_ids) + await pilot_db.delete_pilots(pilot_ids) + + +async def update_pilots_fields( + pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +): + await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) + + +async def add_jobs_to_pilot( + pilot_db: 
PilotAgentsDB, pilot_stamp: str, job_ids: list[int] +): + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids + ] + + await pilot_db.add_jobs_to_pilot( + job_to_pilot_mapping=job_to_pilot_mapping, + ) + + +async def get_pilot_jobs_ids_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamp: str +) -> list[int]: + """Fetch pilot jobs by stamp.""" + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + except PilotNotFoundError: + return [] + + return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py new file mode 100644 index 000000000..6dd46abaa --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import ( + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SearchParams, + SearchSpec, + SortDirection, + SummaryParams, + VectorSearchOperator, + VectorSearchSpec, +) +from diracx.db.os.pilot_logs import PilotLogsDB +from diracx.db.sql import PilotAgentsDB + +MAX_PER_PAGE = 10000 + + +async def search( + pilot_db: PilotAgentsDB, + user_vo: str, + page: int = 1, + per_page: int = 100, + body: SearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + body.search.append( + 
ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo + ) + ) + + total, pilots = await pilot_db.search_pilots( + body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + return total, pilots + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] = [], + allow_missing: bool = True, +) -> list[dict[Any, Any]]: + """Get pilots by their stamp. + + If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. + """ + if parameters: + parameters.append("PilotStamp") + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # allow_missing is set as True by default to mark explicitly when we allow or not + if not allow_missing: + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + data={"pilot_stamp": str(missing)}, + detail=str(missing), + non_existing_pilots=missing, + ) + + return pilots + + +async def get_pilot_ids_by_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False +) -> list[int]: + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["PilotID"], + allow_missing=allow_missing, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + 
distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [job["JobID"] for job in jobs] + + +async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: + _, pilots = await pilot_db.search_pilot_to_job_mapping( + parameters=["PilotID"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_outdated_pilots( + pilot_db: PilotAgentsDB, + cutoff_date: datetime, + vo_constraint: str, + only_aborted: bool = True, + parameters: list[str] = [], +): + query: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff_date, + ), + # Add VO to avoid deleting other VO's pilots + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint + ), + ] + + if only_aborted: + query.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, search=query, sorts=[] + ) + + return pilots + + +async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): + """Show information suitable for plotting.""" + body.search.append( + { + "parameter": "VO", + "operator": ScalarSearchOperator.EQUAL, + "value": vo, + } + ) + return await pilot_db.pilot_summary(body.grouping, body.search) + + +async def search_logs( + vo: str, + body: SearchParams | None, + per_page: int, + page: int, + pilot_logs_db: PilotLogsDB, +) -> tuple[int, list[dict]]: + """Retrieve logs from OpenSearch for a given PilotStamp.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + search = body.search + parameters = body.parameters + sorts = body.sort + + # Add 
the vo to make sure that we filter for pilots we can see + # TODO: Test it + search = search + [ + { + "parameter": "VO", + "operator": "eq", + "value": vo, + } + ] + + if not sorts: + sorts = [{"parameter": "TimeStamp", "direction": SortDirection("asc")}] + + return await pilot_logs_db.search( + parameters=parameters, search=search, sorts=sorts, per_page=per_page, page=page + ) diff --git a/diracx-logic/src/diracx/logic/pilots/resources.py b/diracx-logic/src/diracx/logic/pilots/resources.py new file mode 100644 index 000000000..292a74c9a --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/resources.py @@ -0,0 +1,45 @@ +"""File dedicated to logic for pilot only resources (logs, jobs, etc.).""" + +from __future__ import annotations + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models import LogLine +from diracx.db.os.pilot_logs import PilotLogsDB +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .query import get_pilot_ids_by_stamps + + +async def send_message( + lines: list[LogLine], + pilot_logs_db: PilotLogsDB, + pilot_db: PilotAgentsDB, + vo: str, + pilot_stamp: str, +): + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] # Semantic + except PilotNotFoundError: + # If a pilot is not found, then we still store the data (to not lost it) + # We log it as it's not supposed to happen + # If we arrive here, the pilot as been deleted but is still "alive" + pilot_id = -1 # To detect + + docs = [] + for line in lines: + docs.append( + { + "PilotStamp": pilot_stamp, + "PilotID": pilot_id, + "VO": vo, + "Severity": line.severity, + "Message": line.message, + "TimeStamp": line.timestamp, + "Scope": line.scope, + } + ) + # bulk insert pilot logs to OpenSearch DB: + await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(vo, pilot_id), docs) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 6f554c74e..d97fa1aa2 100644 --- 
a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -46,10 +46,14 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +pilots = "diracx.routers.pilots:router" +"pilots/legacy" = "diracx.routers.legacy_pilot_resources:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +PilotManagementAccessPolicy = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" +LegacyPilotAccessPolicy = "diracx.routers.legacy_pilot_resources.access_policies:LegacyPilotAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/dependencies.py b/diracx-routers/src/diracx/routers/dependencies.py index 8eb2bd265..88a5be6d0 100644 --- a/diracx-routers/src/diracx/routers/dependencies.py +++ b/diracx-routers/src/diracx/routers/dependencies.py @@ -23,6 +23,7 @@ from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings from diracx.core.settings import SandboxStoreSettings as _SandboxStoreSettings from diracx.db.os import JobParametersDB as _JobParametersDB +from diracx.db.os import PilotLogsDB as _PilotLogsDB from diracx.db.sql import AuthDB as _AuthDB from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB @@ -50,6 +51,7 @@ def add_settings_annotation(cls: T) -> T: # Opensearch databases JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)] +PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)] # Miscellaneous diff --git a/diracx-routers/src/diracx/routers/jobs/access_policies.py b/diracx-routers/src/diracx/routers/jobs/access_policies.py index 1fd5a63ae..2239e4764 100644 --- 
a/diracx-routers/src/diracx/routers/jobs/access_policies.py +++ b/diracx-routers/src/diracx/routers/jobs/access_policies.py @@ -6,6 +6,7 @@ from fastapi import Depends, HTTPException, status +from diracx.core.models import VectorSearchOperator, VectorSearchSpec from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.db.sql import JobDB, SandboxMetadataDB from diracx.routers.access_policies import BaseAccessPolicy @@ -82,9 +83,13 @@ async def policy( # Now we know we are either in READ/MODIFY for a NORMAL_USER # so just make sure that whatever job_id was given belongs # to the current user - job_owners = await job_db.summary( + job_owners = await job_db.job_summary( ["Owner", "VO"], - [{"parameter": "JobID", "operator": "in", "values": job_ids}], + [ + VectorSearchSpec( + parameter="JobID", operator=VectorSearchOperator.IN, values=job_ids + ) + ], ) expected_owner = { diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py index a8667b7dd..db270ca4d 100644 --- a/diracx-routers/src/diracx/routers/jobs/query.py +++ b/diracx-routers/src/diracx/routers/jobs/query.py @@ -6,8 +6,8 @@ from fastapi import Body, Depends, Response from diracx.core.models import ( - JobSearchParams, - JobSummaryParams, + SearchParams, + SummaryParams, ) from diracx.core.properties import JOB_ADMINISTRATOR from diracx.logic.jobs.query import search as search_bl @@ -135,7 +135,7 @@ async def search( page: int = 1, per_page: int = 100, body: Annotated[ - JobSearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) ] = None, ) -> list[dict[str, Any]]: """Retrieve information about jobs. 
@@ -183,7 +183,7 @@ async def summary( config: Config, job_db: JobDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - body: JobSummaryParams, + body: SummaryParams, check_permissions: CheckWMSPolicyCallable, ): """Show information suitable for plotting.""" diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py new file mode 100644 index 000000000..367fc6c93 --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .logs import router as legacy_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter(require_auth=False) +router.include_router(legacy_router) diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py new file mode 100644 index 000000000..187323e8e --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/access_policies.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.properties import GENERIC_PILOT, LIMITED_DELEGATION +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class LegacyPilotAccessPolicy(BaseAccessPolicy): + """Rules: + * Every user can access data about his VO + * An administrator can modify a pilot. 
+ """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + ): + if ( + LIMITED_DELEGATION not in user_info.properties + and GENERIC_PILOT not in user_info.properties + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You must be a pilot to access this resource.", + ) + + return + + +CheckLegacyPilotPolicyCallable = Annotated[ + Callable, Depends(LegacyPilotAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py b/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py new file mode 100644 index 000000000..452957d95 --- /dev/null +++ b/diracx-routers/src/diracx/routers/legacy_pilot_resources/logs.py @@ -0,0 +1,51 @@ +"""File dedicated to legacy pilot resources: pilots with DIRAC auth, without JWT.""" + +from __future__ import annotations + +import logging +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends + +from diracx.core.models import LogLine +from diracx.logic.pilots.resources import send_message as send_message_bl +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..dependencies import PilotAgentsDB, PilotLogsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + CheckLegacyPilotPolicyCallable, +) + +logger = logging.getLogger(__name__) +router = DiracxRouter() + + +@router.post("/message", status_code=HTTPStatus.NO_CONTENT) +async def send_message( + lines: Annotated[ + list[LogLine], + Body(description="Message from the pilot to the logging system.", embed=True), + ], + pilot_stamp: Annotated[ + str, + Body( + description="PilotStamp, required as legacy pilots do not have a token with stamp in it." 
+ ), + ], + pilot_logs_db: PilotLogsDB, + pilot_db: PilotAgentsDB, + check_permissions: CheckLegacyPilotPolicyCallable, + pilot_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """Send logs with legacy pilot.""" + await check_permissions() + + await send_message_bl( + lines=lines, + pilot_logs_db=pilot_logs_db, + pilot_db=pilot_db, + vo=pilot_info.vo, + pilot_stamp=pilot_stamp, + ) diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..03f9b8422 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .management import router as management_router +from .query import router as query_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter() +router.include_router(management_router) +router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..a19ca4537 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.models import VectorSearchOperator, VectorSearchSpec +from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + # Change some pilot fields + MANAGE_PILOTS = auto() + # Read 
some pilot info + READ_PILOT_FIELDS = auto() + # Legacy Pilot + LEGACY_PILOT = auto() + + +class PilotManagementAccessPolicy(BaseAccessPolicy): + """Rules: + * Every user can access data about his VO + * An administrator can modify a pilot. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, + allow_legacy_pilots: bool = False, + ): + assert action, "action is a mandatory parameter" + + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) + + if action == ActionType.LEGACY_PILOT: + if is_a_pilot_if_allowed: + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You must be a pilot to access this resource.", + ) + + # Users can query + # NOTE: Add into queries a VO constraint + # To manage pilots, user have to be an admin + # In some special cases (described with allow_legacy_pilots), we can allow pilots + if action == ActionType.MANAGE_PILOTS: + # To make it clear, we separate + is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties + + if not is_an_admin and not is_a_pilot_if_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) + + if action == ActionType.READ_PILOT_FIELDS: + if GENERIC_PILOT in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can't read other pilots info.", + ) + + # + # Additional checks if job_ids or pilot_stamps are provided + # + + # First, if job_ids are provided, we check who is the owner + if job_db and job_ids: + job_owners = await job_db.job_summary( + ["Owner", "VO"], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ], + ) + + 
expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if not job_owners == [expected_owner]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to modify a pilot.", + ) + + # This is for example when we submit pilots, we use the user VO, so no need to verify + if pilot_db and pilot_stamps: + # Else, check its VO + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], + allow_missing=True, + ) + + if len(pilots) != len(pilot_stamps): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot does not exist.", + ) + + if not all(pilot["VO"] == user_info.vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to all pilots.", + ) + + +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py new file mode 100644 index 000000000..21ff63796 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends, HTTPException, Query, status + +from diracx.core.exceptions import ( + PilotAlreadyExistsError, +) +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, +) +from diracx.core.properties import GENERIC_PILOT +from diracx.logic.pilots.management import ( + delete_pilots as delete_pilots_bl, +) +from diracx.logic.pilots.management import ( + get_pilot_jobs_ids_by_stamp, + register_new_pilots, + update_pilots_fields, +) +from diracx.logic.pilots.query import get_pilot_ids_by_job_id 
+from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..dependencies import JobDB, PilotAgentsDB +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + + +@router.post("/") +async def add_pilot_stamps( + pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str], + Body(description="List of the pilot stamps we want to add to the db."), + ], + vo: Annotated[str, Body(description="Pilot virtual organization.")], + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", + destination_site: Annotated[ + str, Body(description="Pilots destination site.") + ] = "NotAssigned", + pilot_references: Annotated[ + dict[str, str] | None, + Body(description="Association of a pilot reference with a pilot stamp."), + ] = None, + pilot_status: Annotated[ + PilotStatus, Body(description="Status of the pilots.") + ] = PilotStatus.SUBMITTED, +): + """Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + """ + # TODO: Verify that grid types, sites, destination sites, etc. 
are valids + await check_permissions( + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to create thousands of pilots at a time + # (It would be still able to create thousands of pilots, but slower) + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only create yourself.", + ) + + try: + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_job_references=pilot_references, + status=pilot_status, + ) + except PilotAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/", status_code=HTTPStatus.NO_CONTENT) +async def delete_pilots( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + pilot_stamps: Annotated[ + list[str] | None, Query(description="Stamps of the pilots we want to delete.") + ] = None, + age_in_days: Annotated[ + int | None, + Query( + description=( + "The number of days that define the maximum age of pilots to be deleted." + "Pilots older than this age will be considered for deletion." + ) + ), + ] = None, + delete_only_aborted: Annotated[ + bool, + Query( + description=( + "Flag indicating whether to only delete pilots whose status is 'Aborted'." + "If set to True, only pilots with the 'Aborted' status will be deleted." + "It is set by default as True to avoid any mistake." + "This flag is only used for deletion by time." + ) + ), + ] = False, +): + """Endpoint to delete a pilot. + + Two features: + + 1. Or you provide pilot_stamps, so you can delete pilots by their stamp + 2. 
Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + """ + vo_constraint: str | None = None + + # If we delete by pilot_stamps, we check that we can access them + # Else, we add a constraint to the request, to avoid deleting pilots from another VO + if pilot_stamps: + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + else: + vo_constraint = user_info.vo + + if not pilot_stamps and not age_in_days: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="pilot_stamps or age_in_days have to be provided.", + ) + + await delete_pilots_bl( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + vo_constraint=vo_constraint, + ) + + +EXAMPLE_UPDATE_FIELDS = { + "Update the BenchMark field": { + "summary": "Update BenchMark", + "description": "Update only the BenchMark for one pilot.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} + ] + }, + }, + "Update multiple statuses": { + "summary": "Update multiple pilots", + "description": "Update multiple pilots statuses.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, + {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + ] + }, + }, +} + + +@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) +async def update_pilot_fields( + pilot_stamps_to_fields_mapping: Annotated[ + list[PilotFieldsMapping], + Body( + description="(pilot_stamp, pilot_fields) mapping to change.", + embed=True, + openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore + ), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], 
+): + """Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + """ + # Ensures stamps validity + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time + # (It would be still able to modify thousands of pilots, but slower) + # We are not able to affirm that this pilot modifies itself + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only modify yourself.", + ) + + await update_pilots_fields( + pilot_db=pilot_db, + pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, + ) + + +@router.get("/jobs") +async def get_pilot_jobs( + pilot_db: PilotAgentsDB, + job_db: JobDB, + check_permissions: CheckPilotManagementPolicyCallable, + pilot_stamp: Annotated[ + str | None, Query(description="The stamp of the pilot.") + ] = None, + job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, +) -> list[int]: + """Endpoint only for admins, to get jobs of a pilot.""" + if pilot_stamp: + # Check VO + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + ) + + return await get_pilot_jobs_ids_by_stamp( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + ) + elif job_id: + # Check job owner + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + ) + + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You must provide either pilot_stamp or job_id", + ) diff --git 
a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py new file mode 100644 index 000000000..7044a0326 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import Body, Depends, Response +from opensearchpy import RequestError + +from diracx.core.models import SearchParams, SummaryParams +from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import search_logs as search_logs_bl +from diracx.logic.pilots.query import summary as summary_bl + +from ..dependencies import PilotAgentsDB, PilotLogsDB +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +EXAMPLE_SEARCHES = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get ordered pilot statuses": { + "summary": "Get ordered pilot statuses", + "description": "Get only pilot statuses for specific pilots, ordered by status", + "value": { + "parameters": ["PilotID", "Status"], + "search": [ + {"parameter": "PilotID", "operator": "in", "values": ["6", "2", "3"]} + ], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": 
"RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of pilots returned in this response", + "schema": {"type": "string", "example": "pilots 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, +} + + +@router.post("/search", responses=EXAMPLE_RESPONSES) +async def search( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore + ] = None, +) -> list[dict[str, Any]]: + """Retrieve information about pilots.""" + # Inspired by /api/jobs/query + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + total, pilots = await search_bl( + pilot_db=pilot_db, + user_vo=user_info.vo, + page=page, + per_page=per_page, + body=body, + ) + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No pilots found but there are pilots for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(pilots) == 0 and 
total > 0: + response.headers["Content-Range"] = f"pilots */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of pilots is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(pilots) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return pilots + + +EXAMPLE_SEARCHES_LOGS = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get a specific severity": { + "summary": "Get ordered pilot statuses", + "description": 'Get only pilot logs that have a severity of "ERROR", ordered by PilotID', + "value": { + "parameters": ["PilotID", "Severity"], + "search": [{"parameter": "Severity", "operator": "eq", "value": "ERROR"}], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES_LOGS: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "Severity": "ERROR", + "TimeStamp": "2023-05-25T07:03:35.602656", + }, + { + "PilotID": 5, + "Severity": "INFO", + "TimeStamp": "2023-07-25T07:03:35.602652", + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. 
@router.post("/search/logs", responses=EXAMPLE_RESPONSES_LOGS)
async def search_logs(
    pilot_logs_db: PilotLogsDB,
    response: Response,
    user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
    check_permissions: CheckPilotManagementPolicyCallable,
    page: int = 1,
    per_page: int = 100,
    body: Annotated[
        SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES_LOGS)
    ] = None,
) -> list[dict[str, Any]]:
    """Retrieve pilot log records.

    Mirrors the ``/search`` endpoint: results are paginated, and a
    ``Content-Range`` header (RFC 7233) is set when only part of the matching
    logs is returned.
    """
    # users will only see logs from their own VO if enforced by a policy:
    await check_permissions(
        action=ActionType.READ_PILOT_FIELDS,
    )

    try:
        total, logs = await search_logs_bl(
            vo=user_info.vo,
            body=body,
            per_page=per_page,
            page=page,
            pilot_logs_db=pilot_logs_db,
        )
    except RequestError:
        # NOTE(review): presumably raised when the backing OpenSearch index
        # does not exist yet (no logs ever written); treated as an empty
        # result rather than a server error — confirm this is intended.
        total, logs = 0, []

    # Set the Content-Range header if needed
    # https://datatracker.ietf.org/doc/html/rfc7233#section-4

    # No logs found but there are logs for the requested search
    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4
    if len(logs) == 0 and total > 0:
        response.headers["Content-Range"] = f"logs */{total}"
        response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE

    # The total number of logs is greater than the number of logs returned
    # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2
    elif len(logs) < total:
        first_idx = per_page * (page - 1)
        last_idx = min(first_idx + len(logs), total) - 1 if total > 0 else 0
        response.headers["Content-Range"] = f"logs {first_idx}-{last_idx}/{total}"
        response.status_code = HTTPStatus.PARTIAL_CONTENT
    return logs
async def test_create_pilots(normal_test_client):
    """Bulk-create pilots and check duplicate handling is atomic."""
    # Lots of requests, to validate that the credentials are returned in the
    # same order as the input references
    pilot_stamps = [f"stamps_{i}" for i in range(N)]

    # -------------- Bulk insert --------------
    body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO}

    r = normal_test_client.post(
        "/api/pilots/",
        json=body,
    )

    assert r.status_code == 200, r.json()

    # -------------- Register a pilot that already exists, and one that does not --------------

    body = {
        "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"],
        "vo": MAIN_VO,
    }

    r = normal_test_client.post(
        "/api/pilots/",
        json=body,
        headers={
            "Content-Type": "application/json",
        },
    )

    assert r.status_code == 409
    assert (
        r.json()["detail"]
        == f"Pilot (pilot_stamps: {{'{pilot_stamps[0]}'}}) already exists"
    )

    # -------------- Register a pilot that does not exist **but** appeared in an earlier failed request --------------
    # To prove that, if I tried to register a pilot that does not exist together
    # with one that already exists, I can afterwards add the one that did not
    # exist (the failed request must not have inserted it)
    body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO}

    r = normal_test_client.post(
        "/api/pilots/",
        json=body,
        headers={
            "Content-Type": "application/json",
        },
    )

    assert r.status_code == 200
@pytest.mark.asyncio
async def test_delete_pilots_by_age_and_stamp(normal_test_client):
    """Exercise deletion by age, by age+status, and by explicit stamps."""
    # Generate 100 pilot stamps
    pilot_stamps = [f"stamp_{i}" for i in range(100)]

    # -------------- Insert all pilots --------------
    body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO}
    r = normal_test_client.post("/api/pilots/", json=body)
    assert r.status_code == 200, r.json()

    # -------------- Modify last 50 pilots' fields --------------
    to_modify = pilot_stamps[50:]
    mappings = []
    for idx, stamp in enumerate(to_modify):
        # First 25 of modified set to ABORTED, others to WAITING
        status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING
        mapping = PilotFieldsMapping(
            PilotStamp=stamp,
            BenchMark=idx + 0.1,
            StatusReason=f"Reason_{idx}",
            AccountingSent=(idx % 2 == 0),
            Status=status,
        ).model_dump(exclude_unset=True)
        mappings.append(mapping)

    r = normal_test_client.patch(
        "/api/pilots/metadata",
        json={"pilot_stamps_to_fields_mapping": mappings},
    )
    assert r.status_code == 204

    # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 --------------
    # SubmissionTime is not settable through the API, so it is forced in the
    # DB directly to make those pilots "old".
    old_date = datetime(2003, 3, 14, tzinfo=timezone.utc)
    # Access DB session from normal_test_client fixtures
    db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0]

    async with db:
        stmt = (
            update(PilotAgents)
            .where(PilotAgents.pilot_stamp.in_(to_modify))
            .values(SubmissionTime=old_date)
        )
        await db.conn.execute(stmt)
        await db.conn.commit()

    # -------------- Verify all 100 pilots exist --------------
    search_body = {"parameters": [], "search": [], "sort": [], "distinct": True}
    r = normal_test_client.post("/api/pilots/search", json=search_body)
    assert r.status_code == 200, r.json()
    assert len(r.json()) == 100

    # -------------- 1) Delete only old aborted pilots (25 expected) --------------
    # age_in_days large enough to include 2003-03-14
    r = normal_test_client.delete(
        "/api/pilots/",
        params={"age_in_days": 15, "delete_only_aborted": True},
    )
    assert r.status_code == 204
    # Expect 75 remaining
    r = normal_test_client.post("/api/pilots/search", json=search_body)
    assert len(r.json()) == 75

    # -------------- 2) Delete all old pilots (remaining 25 old) --------------
    r = normal_test_client.delete(
        "/api/pilots/",
        params={"age_in_days": 15},
    )
    assert r.status_code == 204

    # Expect 50 remaining
    r = normal_test_client.post("/api/pilots/search", json=search_body)
    assert len(r.json()) == 50

    # -------------- 3) Delete one recent pilot by stamp --------------
    one_stamp = pilot_stamps[10]
    r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]})
    assert r.status_code == 204
    # Expect 49 remaining
    r = normal_test_client.post("/api/pilots/search", json=search_body)
    assert len(r.json()) == 49

    # -------------- 4) Delete all remaining pilots --------------
    # Collect remaining stamps
    remaining = [p["PilotStamp"] for p in r.json()]
    r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining})
    assert r.status_code == 204
    # Expect none remaining
    r = normal_test_client.post("/api/pilots/search", json=search_body)
    assert r.status_code == 200
    assert len(r.json()) == 0

    # -------------- 5) Attempt deleting unknown pilot, expect 204 --------------
    # NOTE(review): this comment previously said "expect 400", but the
    # assertion below expects 204 (deletion appears idempotent) — confirm
    # which behaviour is intended.
    r = normal_test_client.delete(
        "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]}
    )
    assert r.status_code == 204
create_logs(create_pilots, normal_test_client): + for i, stamp in enumerate(create_pilots): + lines = [ + { + "message": stamp, + "timestamp": "2022-02-26 13:48:35.123456", + "scope": "PilotParams" if i % 2 == 1 else "Commands", + "severity": "DEBUG" if i % 2 == 0 else "INFO", + } + ] + msg_dict = {"lines": lines, "pilot_stamp": stamp} + r = normal_test_client.post("/api/pilots/legacy/message", json=msg_dict) + + assert r.status_code == 204, r.json() + # Return only stamps + return create_pilots + + +@pytest.fixture +async def search(normal_test_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = normal_test_client.post("/api/pilots/search/logs", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_single_send_and_retrieve_logs(normal_test_client: TestClient): + # Add a pilot stamps + pilot_stamp = ["stamp_1"] + + # -------------- Bulk insert -------------- + body = {"vo": "lhcb", "pilot_stamps": pilot_stamp} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + msg = "JSON file loaded: pilot.json\nJSON file analysed: pilot.json" + # message dict + lines = [] + for line in msg.split("\n"): + lines.append( + { + "message": line, + "timestamp": "2022-02-26 13:48:35.123456", + "scope": "PilotParams", + "severity": "DEBUG", + } + ) + msg_dict = {"lines": lines, "pilot_stamp": "stamp_1"} + + # send message + r = normal_test_client.post("/api/pilots/legacy/message", json=msg_dict) + + assert r.status_code == 204, r.json() + # get the message back: + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": 
"stamp_1"}] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert [hit["Message"] for hit in r.json()] == msg.split("\n") + + +async def test_query_invalid_stamp(create_logs, normal_test_client): + data = { + "search": [ + {"parameter": "PilotStamp", "operator": "eq", "value": "not_a_stamp"} + ] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 0 + + +async def test_query_each_length(create_logs, normal_test_client): + for stamp in create_logs: + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": stamp}] + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 1 + + +async def test_query_each_field(create_logs, normal_test_client): + for i, stamp in enumerate(create_logs): + data = { + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": stamp}], + "sort": [{"parameter": "PilotStamp", "direction": "asc"}], + } + r = normal_test_client.post("/api/pilots/search/logs", json=data) + assert r.status_code == 200, r.text + assert len(r.json()) == 1 + + # Reminder: + + # "message": str(i), + # "timestamp": "2022-02-26 13:48:35.123456", + # "scope": "PilotParams" if i % 2 == 1 else "Commands", + # "severity": "DEBUG" if i % 2 == 0 else "INFO", + log = r.json()[0] + + assert log["Message"] == f"stamp_{i}" + assert log["Scope"] == ("PilotParams" if i % 2 == 1 else "Commands") + assert log["Severity"] == ("DEBUG" if i % 2 == 0 else "INFO") + + +async def test_search_pagination(create_logs, search): + """Test that we can search for logs.""" + # Search for the first 10 logs + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + 
assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 logs + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 logs + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 logs + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"logs {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py new file mode 100644 index 000000000..c6d5cedb4 --- /dev/null +++ b/diracx-routers/tests/pilots/test_query.py @@ -0,0 +1,414 @@ +"""Inspired by pilots and jobs db search tests.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import 
TestClient + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models import ( + PilotFieldsMapping, + PilotStatus, + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ConfigSource", + "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +MAIN_VO = "lhcb" +N = 100 + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_client(normal_test_client): + pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, pilot_stamp in enumerate(pilot_stamps) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + yield normal_test_client + + +async def test_pilot_summary(populated_pilot_client: TestClient): + # Group by StatusReason + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["StatusReason"], + }, + ) + + assert r.status_code == 200 + + assert sum([el["count"] for el in r.json()]) == N + assert len(r.json()) == len(PILOT_REASONS) + + # Group by CurrentJobID + r = populated_pilot_client.post( + 
"/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + }, + ) + + assert r.status_code == 200 + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == N + + # Group by CurrentJobID where BenchMark < 10^2 + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}], + }, + ) + + assert r.status_code == 200, r.json() + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == 10 + + +@pytest.fixture +async def search(populated_pilot_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_search_parameters(search): + """Test that we can search specific parameters for pilots.""" + # Search a specific parameter: PilotID + result, headers = await search(["PilotID"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + assert "Content-Range" not in headers + + # Search a specific parameter: Status + result, headers = await search(["Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"Status"} + assert "Content-Range" not in headers + + # Search for multiple parameters: PilotID, Status + result, headers = await search(["PilotID", "Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + assert "Content-Range" not in headers + + # Search for a specific parameter but use 
async def test_search_conditions(search):
    """Test that we can search for specific pilots."""
    # Search a specific scalar condition: PilotID eq 3
    condition = ScalarSearchSpec(
        parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 1
    assert result
    assert len(result) == 1
    assert result[0]["PilotID"] == 3
    assert "Content-Range" not in headers

    # Search a specific scalar condition: PilotID lt 3
    condition = ScalarSearchSpec(
        parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 2
    assert result
    assert len(result) == 2
    assert result[0]["PilotID"] == 1
    assert result[1]["PilotID"] == 2
    assert "Content-Range" not in headers

    # Search a specific scalar condition: PilotID neq 3
    condition = ScalarSearchSpec(
        parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 99
    assert result
    assert len(result) == 99
    assert all(r["PilotID"] != 3 for r in result)
    assert "Content-Range" not in headers

    # Search a specific scalar condition: PilotID eq 5873 (does not exist)
    condition = ScalarSearchSpec(
        parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873
    )
    result, headers = await search([], [condition], [])
    assert not result
    assert "Content-Range" not in headers

    # Search a specific vector condition: PilotID in 1,2,3
    condition = VectorSearchSpec(
        parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3]
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 3
    assert result
    assert len(result) == 3
    assert all(r["PilotID"] in [1, 2, 3] for r in result)
    assert "Content-Range" not in headers

    # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist)
    condition = VectorSearchSpec(
        parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873]
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 2
    assert result
    assert len(result) == 2
    assert all(r["PilotID"] in [1, 2] for r in result)
    assert "Content-Range" not in headers

    # Search a specific vector condition: PilotID not in 1,2,3
    condition = VectorSearchSpec(
        parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3]
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 97
    assert result
    assert len(result) == 97
    assert all(r["PilotID"] not in [1, 2, 3] for r in result)
    assert "Content-Range" not in headers

    # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist)
    condition = VectorSearchSpec(
        parameter="PilotID",
        operator=VectorSearchOperator.NOT_IN,
        values=[1, 2, 5873],
    )
    result, headers = await search([], [condition], [])
    assert len(result) == 98
    assert result
    assert len(result) == 98
    assert all(r["PilotID"] not in [1, 2] for r in result)
    assert "Content-Range" not in headers

    # Search for multiple conditions based on different parameters:
    # PilotStamp eq "stamp_5", PilotID in 4,5,6
    condition1 = ScalarSearchSpec(
        parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5"
    )
    condition2 = VectorSearchSpec(
        parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
    )
    result, headers = await search([], [condition1, condition2], [])

    assert result
    assert len(result) == 1
    assert result[0]["PilotID"] == 5
    assert result[0]["PilotStamp"] == "stamp_5"
    assert "Content-Range" not in headers

    # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6
    condition1 = ScalarSearchSpec(
        parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70
    )
    condition2 = VectorSearchSpec(
        parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6]
    )
    result, headers = await search([], [condition1, condition2], [])
    assert len(result) == 0
    assert not result
    assert "Content-Range" not in headers
lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort1, sort2]) + assert len(result) == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + assert "Content-Range" not in headers + + +async def test_search_pagination(search): + """Test that we can search for pilots.""" + # Search for the first 10 pilots + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 pilots + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 pilots + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 pilots + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because 
Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index 5f7fe7f93..6a87b0579 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -10,9 +10,10 @@ from functools import partial from typing import Any, AsyncIterator -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from diracx.core.exceptions import InvalidQueryError from diracx.core.models import SearchSpec, SortSpec from diracx.db.sql import utils as sql_utils @@ -53,7 +54,11 @@ def __init__(self, connection_kwargs: dict[str, Any]) -> None: for field, field_type in self.fields.items(): match field_type["type"]: case "date": + # TODO: Warning, maybe this will crash? 
See date_nanos + # I needed to set Varchar because it is sent as 2022-06-15T10:12:52.382719622Z, and not datetime column_type = DateNowColumn + case "date_nanos": + column_type = partial(Column, type_=String(32)) case "long": column_type = partial(Column, type_=Integer) case "keyword": @@ -100,6 +105,21 @@ async def upsert(self, vo, doc_id, document) -> None: stmt = stmt.on_conflict_do_update(index_elements=["doc_id"], set_=values) await self._sql_db.conn.execute(stmt) + async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: + async with self._sql_db: + rows = [] + for doc in docs: + # don't use doc_id column explicitly. This ensures that doc_id is unique. + values = {} + for key, value in doc.items(): + if key in self.fields: + values[key] = value + else: + values.setdefault("extra", {})[key] = value + rows.append(values) + stmt = sqlite_insert(self._table).values(rows) + await self._sql_db.conn.execute(stmt) + async def search( self, parameters: list[str] | None, @@ -135,8 +155,17 @@ async def search( self._table.columns.__getitem__, stmt, sorts ) + # Calculate total count before applying pagination + total_count_subquery = stmt.alias() + total_count_stmt = select(func.count()).select_from(total_count_subquery) + total = (await self._sql_db.conn.execute(total_count_stmt)).scalar_one() + # Apply pagination if page is not None: + if page < 1: + raise InvalidQueryError("Page must be a positive integer") + if per_page < 1: + raise InvalidQueryError("Per page must be a positive integer") stmt = stmt.offset((page - 1) * per_page).limit(per_page) results = [] @@ -151,7 +180,8 @@ async def search( if v is None: result.pop(k) results.append(result) - return results + + return total, results async def ping(self): async with self._sql_db: diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index 65282efb6..321d3238c 100644 --- 
a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -15,10 +15,18 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations - - -class Dirac: # pylint: disable=client-accepts-api-version-keyword +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) + + +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -31,6 +39,10 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". 
:paramtype endpoint: str """ @@ -68,6 +80,8 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index d67986dae..07253331f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -15,10 +15,18 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, LollygagOperations, WellKnownOperations - - -class Dirac: # pylint: disable=client-accepts-api-version-keyword +from .operations import ( + AuthOperations, + ConfigOperations, + JobsOperations, + LollygagOperations, + PilotsLegacyOperations, + PilotsOperations, + WellKnownOperations, +) + + +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. 
:ivar well_known: WellKnownOperations operations @@ -31,6 +39,10 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype jobs: _generated.aio.operations.JobsOperations :ivar lollygag: LollygagOperations operations :vartype lollygag: _generated.aio.operations.LollygagOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations + :ivar pilots_legacy: PilotsLegacyOperations operations + :vartype pilots_legacy: _generated.aio.operations.PilotsLegacyOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -68,6 +80,8 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots_legacy = PilotsLegacyOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 572930a93..759b5d4e6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -15,6 +15,8 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore +from ._operations 
import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +28,8 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index 30d2e1c17..a8b08e1ac 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -55,6 +55,14 @@ build_lollygag_get_gubbins_secrets_request, build_lollygag_get_owner_object_request, build_lollygag_insert_owner_object_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_legacy_send_message_request, + build_pilots_search_logs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1829,7 +1837,7 @@ async def patch_metadata(self, body: Union[Dict[str, Dict[str, Any]], IO[bytes]] @overload async def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -1843,7 +1851,7 @@ async def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. 
@@ -1889,7 +1897,7 @@ async def search( @distributed_trace_async async def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -1901,8 +1909,8 @@ async def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -1932,7 +1940,7 @@ async def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -1971,14 +1979,14 @@ async def search( @overload async def summary( - self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any ) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2004,13 +2012,13 @@ async def summary(self, body: IO[bytes], *, content_type: str = "application/jso """ @distributed_trace_async - async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. 
- :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2035,7 +2043,7 @@ async def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwar if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2324,3 +2332,831 @@ async def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. 
+ Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. 
Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise 
HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def search_logs( + self, + body: Optional[_models.SearchParams] = 
None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, 
content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots_legacy` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def send_message(self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 2c1fc99e9..c889fb017 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -14,6 +14,9 @@ from ._models import ( # type: ignore 
BodyAuthGetOidcToken, BodyAuthGetOidcTokenGrantType, + BodyPilotsAddPilotStamps, + BodyPilotsLegacySendMessage, + BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -21,20 +24,22 @@ InitiateDeviceFlowResponse, InsertedJob, JobCommand, - JobSearchParams, - JobSearchParamsSearchItem, JobStatusUpdate, - JobSummaryParams, - JobSummaryParamsSearchItem, + LogLine, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, ScalarSearchSpec, ScalarSearchSpecValue, + SearchParams, + SearchParamsSearchItem, SetJobStatusReturn, SetJobStatusReturnSuccess, SortSpec, + SummaryParams, + SummaryParamsSearchItem, SupportInfo, TokenResponse, UserInfoResponse, @@ -48,6 +53,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -61,6 +67,9 @@ __all__ = [ "BodyAuthGetOidcToken", "BodyAuthGetOidcTokenGrantType", + "BodyPilotsAddPilotStamps", + "BodyPilotsLegacySendMessage", + "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -68,20 +77,22 @@ "InitiateDeviceFlowResponse", "InsertedJob", "JobCommand", - "JobSearchParams", - "JobSearchParamsSearchItem", "JobStatusUpdate", - "JobSummaryParams", - "JobSummaryParamsSearchItem", + "LogLine", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", "ScalarSearchSpec", "ScalarSearchSpecValue", + "SearchParams", + "SearchParamsSearchItem", "SetJobStatusReturn", "SetJobStatusReturnSuccess", "SortSpec", + "SummaryParams", + "SummaryParamsSearchItem", "SupportInfo", "TokenResponse", "UserInfoResponse", @@ -92,6 +103,7 @@ "VectorSearchSpecValues", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py 
b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py index 8098c62f4..44da9887d 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 714f0317a..c597bfcc1 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -94,6 +94,144 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model): """OAuth2 Grant type.""" +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. 
Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". 
+ :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsLegacySendMessage(_serialization.Model): + """Body_pilots/legacy_send_message. + + All required parameters must be populated in order to send to server. + + :ivar lines: Message from the pilot to the logging system. Required. + :vartype lines: list[~_generated.models.LogLine] + :ivar pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in it. + Required. + :vartype pilot_stamp: str + """ + + _validation = { + "lines": {"required": True}, + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "lines": {"key": "lines", "type": "[LogLine]"}, + "pilot_stamp": {"key": "pilot_stamp", "type": "str"}, + } + + def __init__(self, *, lines: List["_models.LogLine"], pilot_stamp: str, **kwargs: Any) -> None: + """ + :keyword lines: Message from the pilot to the logging system. Required. + :paramtype lines: list[~_generated.models.LogLine] + :keyword pilot_stamp: PilotStamp, required as legacy pilots do not have a token with stamp in + it. Required. + :paramtype pilot_stamp: str + """ + super().__init__(**kwargs) + self.lines = lines + self.pilot_stamp = pilot_stamp + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. 
+ :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -405,56 +543,6 @@ def __init__(self, *, job_id: int, command: str, arguments: Optional[str] = None self.arguments = arguments -class JobSearchParams(_serialization.Model): - """JobSearchParams. - - :ivar parameters: Parameters. - :vartype parameters: list[str] - :ivar search: Search. - :vartype search: list[~_generated.models.JobSearchParamsSearchItem] - :ivar sort: Sort. - :vartype sort: list[~_generated.models.SortSpec] - :ivar distinct: Distinct. - :vartype distinct: bool - """ - - _attribute_map = { - "parameters": {"key": "parameters", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSearchParamsSearchItem]"}, - "sort": {"key": "sort", "type": "[SortSpec]"}, - "distinct": {"key": "distinct", "type": "bool"}, - } - - def __init__( - self, - *, - parameters: Optional[List[str]] = None, - search: List["_models.JobSearchParamsSearchItem"] = [], - sort: List["_models.SortSpec"] = [], - distinct: bool = False, - **kwargs: Any - ) -> None: - """ - :keyword parameters: Parameters. - :paramtype parameters: list[str] - :keyword search: Search. - :paramtype search: list[~_generated.models.JobSearchParamsSearchItem] - :keyword sort: Sort. 
- :paramtype sort: list[~_generated.models.SortSpec] - :keyword distinct: Distinct. - :paramtype distinct: bool - """ - super().__init__(**kwargs) - self.parameters = parameters - self.search = search - self.sort = sort - self.distinct = distinct - - -class JobSearchParamsSearchItem(_serialization.Model): - """JobSearchParamsSearchItem.""" - - class JobStatusUpdate(_serialization.Model): """JobStatusUpdate. @@ -505,42 +593,51 @@ def __init__( self.source = source -class JobSummaryParams(_serialization.Model): - """JobSummaryParams. +class LogLine(_serialization.Model): + """LogLine. All required parameters must be populated in order to send to server. - :ivar grouping: Grouping. Required. - :vartype grouping: list[str] - :ivar search: Search. - :vartype search: list[~_generated.models.JobSummaryParamsSearchItem] + :ivar timestamp: Timestamp. Required. + :vartype timestamp: str + :ivar severity: Severity. Required. + :vartype severity: str + :ivar message: Message. Required. + :vartype message: str + :ivar scope: Scope. Required. + :vartype scope: str """ _validation = { - "grouping": {"required": True}, + "timestamp": {"required": True}, + "severity": {"required": True}, + "message": {"required": True}, + "scope": {"required": True}, } _attribute_map = { - "grouping": {"key": "grouping", "type": "[str]"}, - "search": {"key": "search", "type": "[JobSummaryParamsSearchItem]"}, + "timestamp": {"key": "timestamp", "type": "str"}, + "severity": {"key": "severity", "type": "str"}, + "message": {"key": "message", "type": "str"}, + "scope": {"key": "scope", "type": "str"}, } - def __init__( - self, *, grouping: List[str], search: List["_models.JobSummaryParamsSearchItem"] = [], **kwargs: Any - ) -> None: + def __init__(self, *, timestamp: str, severity: str, message: str, scope: str, **kwargs: Any) -> None: """ - :keyword grouping: Grouping. Required. - :paramtype grouping: list[str] - :keyword search: Search. 
- :paramtype search: list[~_generated.models.JobSummaryParamsSearchItem] + :keyword timestamp: Timestamp. Required. + :paramtype timestamp: str + :keyword severity: Severity. Required. + :paramtype severity: str + :keyword message: Message. Required. + :paramtype message: str + :keyword scope: Scope. Required. + :paramtype scope: str """ super().__init__(**kwargs) - self.grouping = grouping - self.search = search - - -class JobSummaryParamsSearchItem(_serialization.Model): - """JobSummaryParamsSearchItem.""" + self.timestamp = timestamp + self.severity = severity + self.message = message + self.scope = scope class OpenIDConfiguration(_serialization.Model): @@ -676,6 +773,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. 
+ :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. 
+ :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. @@ -857,6 +1050,56 @@ class ScalarSearchSpecValue(_serialization.Model): """Value.""" +class SearchParams(_serialization.Model): + """SearchParams. + + :ivar parameters: Parameters. + :vartype parameters: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SearchParamsSearchItem] + :ivar sort: Sort. + :vartype sort: list[~_generated.models.SortSpec] + :ivar distinct: Distinct. + :vartype distinct: bool + """ + + _attribute_map = { + "parameters": {"key": "parameters", "type": "[str]"}, + "search": {"key": "search", "type": "[SearchParamsSearchItem]"}, + "sort": {"key": "sort", "type": "[SortSpec]"}, + "distinct": {"key": "distinct", "type": "bool"}, + } + + def __init__( + self, + *, + parameters: Optional[List[str]] = None, + search: List["_models.SearchParamsSearchItem"] = [], + sort: List["_models.SortSpec"] = [], + distinct: bool = False, + **kwargs: Any + ) -> None: + """ + :keyword parameters: Parameters. + :paramtype parameters: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SearchParamsSearchItem] + :keyword sort: Sort. + :paramtype sort: list[~_generated.models.SortSpec] + :keyword distinct: Distinct. + :paramtype distinct: bool + """ + super().__init__(**kwargs) + self.parameters = parameters + self.search = search + self.sort = sort + self.distinct = distinct + + +class SearchParamsSearchItem(_serialization.Model): + """SearchParamsSearchItem.""" + + class SetJobStatusReturn(_serialization.Model): """SetJobStatusReturn. 
@@ -1000,6 +1243,44 @@ def __init__(self, *, parameter: str, direction: Union[str, "_models.SortDirecti self.direction = direction +class SummaryParams(_serialization.Model): + """SummaryParams. + + All required parameters must be populated in order to send to server. + + :ivar grouping: Grouping. Required. + :vartype grouping: list[str] + :ivar search: Search. + :vartype search: list[~_generated.models.SummaryParamsSearchItem] + """ + + _validation = { + "grouping": {"required": True}, + } + + _attribute_map = { + "grouping": {"key": "grouping", "type": "[str]"}, + "search": {"key": "search", "type": "[SummaryParamsSearchItem]"}, + } + + def __init__( + self, *, grouping: List[str], search: List["_models.SummaryParamsSearchItem"] = [], **kwargs: Any + ) -> None: + """ + :keyword grouping: Grouping. Required. + :paramtype grouping: list[str] + :keyword search: Search. + :paramtype search: list[~_generated.models.SummaryParamsSearchItem] + """ + super().__init__(**kwargs) + self.grouping = grouping + self.search = search + + +class SummaryParamsSearchItem(_serialization.Model): + """SummaryParamsSearchItem.""" + + class SupportInfo(_serialization.Model): """SupportInfo. 
diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 572930a93..759b5d4e6 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -15,6 +15,8 @@ from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore +from ._operations import PilotsOperations # type: ignore +from ._operations import PilotsLegacyOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -26,6 +28,8 @@ "ConfigOperations", "JobsOperations", "LollygagOperations", + "PilotsOperations", + "PilotsLegacyOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 4e429a056..0db5ffa19 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -639,6 +639,162 @@ def build_lollygag_get_gubbins_secrets_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + 
_headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = 
_SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_logs_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search/logs" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + 
_headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_legacy_send_message_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/legacy/message" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2400,7 +2556,7 @@ def patch_metadata( # pylint: disable=inconsistent-return-statements @overload def search( self, - body: Optional[_models.JobSearchParams] = None, + body: Optional[_models.SearchParams] = None, *, page: int = 1, per_page: int = 100, @@ -2414,7 +2570,7 @@ def search( **TODO: Add more docs**. :param body: Default value is None. - :type body: ~_generated.models.JobSearchParams + :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. 
@@ -2460,7 +2616,7 @@ def search( @distributed_trace def search( self, - body: Optional[Union[_models.JobSearchParams, IO[bytes]]] = None, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, *, page: int = 1, per_page: int = 100, @@ -2472,8 +2628,8 @@ def search( **TODO: Add more docs**. - :param body: Is either a JobSearchParams type or a IO[bytes] type. Default value is None. - :type body: ~_generated.models.JobSearchParams or IO[bytes] + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. :paramtype page: int :keyword per_page: Default value is 100. @@ -2503,7 +2659,7 @@ def search( _content = body else: if body is not None: - _json = self._serialize.body(body, "JobSearchParams") + _json = self._serialize.body(body, "SearchParams") else: _json = None @@ -2541,13 +2697,13 @@ def search( return deserialized # type: ignore @overload - def summary(self, body: _models.JobSummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. :param body: Required. - :type body: ~_generated.models.JobSummaryParams + :type body: ~_generated.models.SummaryParams :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2573,13 +2729,13 @@ def summary(self, body: IO[bytes], *, content_type: str = "application/json", ** """ @distributed_trace - def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: Any) -> Any: + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. Show information suitable for plotting. - :param body: Is either a JobSummaryParams type or a IO[bytes] type. Required. 
- :type body: ~_generated.models.JobSummaryParams or IO[bytes] + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2604,7 +2760,7 @@ def summary(self, body: Union[_models.JobSummaryParams, IO[bytes]], **kwargs: An if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "JobSummaryParams") + _json = self._serialize.body(body, "SummaryParams") _request = build_jobs_summary_request( content_type=content_type, @@ -2893,3 +3049,829 @@ def get_gubbins_secrets(self, **kwargs: Any) -> Any: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. 
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: any
+        :rtype: any
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @overload
+    def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any:
+        """Add Pilot Stamps.
+
+        Endpoint where you can create pilots with their references.
+
+        If a pilot stamp already exists, it will block the insertion.
+
+        :param body: Required.
+        :type body: IO[bytes]
+        :keyword content_type: Body Parameter content-type. Content type parameter for binary body.
+         Default value is "application/json".
+        :paramtype content_type: str
+        :return: any
+        :rtype: any
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+
+    @distributed_trace
+    def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any:
+        """Add Pilot Stamps.
+
+        Endpoint where you can create pilots with their references.
+
+        If a pilot stamp already exists, it will block the insertion.
+
+        :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required.
+ :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. 
Or you provide pilot_stamps, so you can delete pilots by their stamp
+        #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days.
+
+        Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted.
+
+        :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None.
+        :paramtype pilot_stamps: list[str]
+        :keyword age_in_days: The number of days that define the maximum age of pilots to be
+         deleted. Pilots older than this age will be considered for deletion. Default value is None.
+        :paramtype age_in_days: int
+        :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is
+         'Aborted'. If set to True, only pilots with the 'Aborted' status will be deleted, which
+         avoids deleting non-aborted pilots by mistake. This flag is only used for deletion by time.
+         Default value is False.
+        :paramtype delete_only_aborted: bool
+        :return: None
+        :rtype: None
+        :raises ~azure.core.exceptions.HttpResponseError:
+        """
+        error_map: MutableMapping = {
+            401: ClientAuthenticationError,
+            404: ResourceNotFoundError,
+            409: ResourceExistsError,
+            304: ResourceNotModifiedError,
+        }
+        error_map.update(kwargs.pop("error_map", {}) or {})
+
+        _headers = kwargs.pop("headers", {}) or {}
+        _params = kwargs.pop("params", {}) or {}
+
+        cls: ClsType[None] = kwargs.pop("cls", None)
+
+        _request = build_pilots_delete_pilots_request(
+            pilot_stamps=pilot_stamps,
+            age_in_days=age_in_days,
+            delete_only_aborted=delete_only_aborted,
+            headers=_headers,
+            params=_params,
+        )
+        _request.url = self._client.format_url(_request.url)
+
+        _stream = False
+        pipeline_response: PipelineResponse = self._client._pipeline.run(  # pylint: disable=protected-access
+            _request, stream=_stream, **kwargs
+        )
+
+        response = pipeline_response.http_response
+
+        if response.status_code not in [204]:
+            map_error(status_code=response.status_code, response=response, error_map=error_map)
+            raise
HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. 
+ :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". 
+ :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def search_logs( + self, + body: Optional[_models.SearchParams] = None, + *, + 
page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search_logs( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search_logs( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search Logs. + + Search Logs. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. 
+ :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_logs_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = 
"application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + +class PilotsLegacyOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots_legacy` attribute. 
+ """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def send_message( + self, body: _models.BodyPilotsLegacySendMessage, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsLegacySendMessage + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def send_message(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def send_message( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsLegacySendMessage, IO[bytes]], **kwargs: Any + ) -> None: + """Send Message. + + Send logs with legacy pilot. + + :param body: Is either a BodyPilotsLegacySendMessage type or a IO[bytes] type. Required. 
+ :type body: ~_generated.models.BodyPilotsLegacySendMessage or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsLegacySendMessage") + + _request = build_pilots_legacy_send_message_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore diff --git a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py index 467577394..a32823a3a 100644 --- a/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py +++ b/extensions/gubbins/gubbins-db/src/gubbins/db/sql/lollygag/db.py @@ -21,7 +21,7 @@ class LollygagDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = LollygagDBBase.metadata - async def 
summary(self, group_by, search) -> list[dict[str, str | int]]: + async def lollygag_summary(self, group_by, search) -> list[dict[str, str | int]]: columns = [Cars.__table__.columns[x] for x in group_by] stmt = select(*columns, func.count(Cars.license_plate).label("count")) diff --git a/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py index b5ff7b84e..43d69e91e 100644 --- a/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py +++ b/extensions/gubbins/gubbins-db/tests/test_lollygag_db.py @@ -31,7 +31,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): # So it is important to write test this way async with lollygag_db as lollygag_db: # First we check that the DB is empty - result = await lollygag_db.summary(["Model"], []) + result = await lollygag_db.lollygag_summary(["Model"], []) assert not result # Now we add some data in the DB @@ -51,13 +51,13 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): # Check that there are now 10 cars assigned to a single driver async with lollygag_db as lollygag_db: - result = await lollygag_db.summary(["OwnerID"], []) + result = await lollygag_db.lollygag_summary(["OwnerID"], []) assert result[0]["count"] == 10 # Test the selection async with lollygag_db as lollygag_db: - result = await lollygag_db.summary( + result = await lollygag_db.lollygag_summary( ["OwnerID"], [{"parameter": "Model", "operator": "eq", "value": "model_1"}] ) @@ -65,7 +65,7 @@ async def test_insert_and_summary(lollygag_db: LollygagDB): async with lollygag_db as lollygag_db: with pytest.raises(InvalidQueryError): - result = await lollygag_db.summary( + result = await lollygag_db.lollygag_summary( ["OwnerID"], [ {