From 3f9b65f75fbbe17a678ad84216643775bccad915 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 8 May 2025 14:07:16 -0700 Subject: [PATCH 01/27] chore: updated gapic layer for execute_query --- .../services/firestore/async_client.py | 107 +++ .../firestore_v1/services/firestore/client.py | 105 +++ .../services/firestore/transports/base.py | 17 + .../services/firestore/transports/grpc.py | 28 + .../firestore/transports/grpc_asyncio.py | 33 + .../services/firestore/transports/rest.py | 266 +++++- .../firestore/transports/rest_base.py | 120 ++- google/cloud/firestore_v1/types/__init__.py | 16 + google/cloud/firestore_v1/types/document.py | 165 ++++ .../cloud/firestore_v1/types/explain_stats.py | 53 ++ google/cloud/firestore_v1/types/firestore.py | 145 +++ google/cloud/firestore_v1/types/pipeline.py | 61 ++ .../unit/gapic/firestore_v1/test_firestore.py | 855 ++++++++++++++++-- 13 files changed, 1815 insertions(+), 156 deletions(-) create mode 100644 google/cloud/firestore_v1/types/explain_stats.py create mode 100644 google/cloud/firestore_v1/types/pipeline.py diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index 56cf7d3af..916914969 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -52,6 +52,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -236,6 +237,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1247,6 +1251,109 @@ async def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[firestore.ExecutePipelineResponse]]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+            # - It may require specifying regional endpoints when creating the service
+            # client as shown in:
+            # https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.cloud import firestore_v1
+
+            async def sample_execute_pipeline():
+                # Create a client
+                client = firestore_v1.FirestoreAsyncClient()
+
+                # Initialize request argument(s)
+                structured_pipeline = firestore_v1.StructuredPipeline()
+                structured_pipeline.pipeline.stages.name = "name_value"
+
+                request = firestore_v1.ExecutePipelineRequest(
+                    structured_pipeline=structured_pipeline,
+                    transaction=b'transaction_blob',
+                    database="database_value",
+                )
+
+                # Make the request
+                stream = await client.execute_pipeline(request=request)
+
+                # Handle the response
+                async for response in stream:
+                    print(response)
+
+        Args:
+            request (Optional[Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]]):
+                The request object. The request for
+                [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline].
+            retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
+                sent along with the request as metadata. Normally, each value must be of type `str`,
+                but for metadata keys ending with the suffix `-bin`, the corresponding values must
+                be of type `bytes`.
+
+        Returns:
+            AsyncIterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]:
+                The response for [Firestore.Execute][].
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, firestore.ExecutePipelineRequest):
+            request = firestore.ExecutePipelineRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._client._transport._wrapped_methods[
+            self._client._transport.execute_pipeline
+        ]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._client._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 1fb800e61..340cd5ef2 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -67,6 +67,7 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile @@ -551,6 +552,9 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. + NOTE: "rest" transport functionality is currently in a + beta state (preview). We welcome your feedback via an + issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. @@ -1630,6 +1634,107 @@ def sample_run_query(): # Done; return the response. return response + def execute_pipeline( + self, + request: Optional[Union[firestore.ExecutePipelineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[firestore.ExecutePipelineResponse]: + r"""Executes a pipeline query. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_v1 + + def sample_execute_pipeline(): + # Create a client + client = firestore_v1.FirestoreClient() + + # Initialize request argument(s) + structured_pipeline = firestore_v1.StructuredPipeline() + structured_pipeline.pipeline.stages.name = "name_value" + + request = firestore_v1.ExecutePipelineRequest( + structured_pipeline=structured_pipeline, + transaction=b'transaction_blob', + database="database_value", + ) + + # Make the request + stream = client.execute_pipeline(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.firestore_v1.types.ExecutePipelineRequest, dict]): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + Iterable[google.cloud.firestore_v1.types.ExecutePipelineResponse]: + The response for [Firestore.Execute][]. 
+ """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore.ExecutePipelineRequest): + request = firestore.ExecutePipelineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.execute_pipeline] + + header_params = {} + + routing_param_regex = re.compile("^projects/(?P[^/]+)(?:/.*)?$") + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("project_id"): + header_params["project_id"] = regex_match.group("project_id") + + routing_param_regex = re.compile( + "^projects/[^/]+/databases/(?P[^/]+)(?:/.*)?$" + ) + regex_match = routing_param_regex.match(request.database) + if regex_match and regex_match.group("database_id"): + header_params["database_id"] = regex_match.group("database_id") + + if header_params: + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(header_params), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + def run_aggregation_query( self, request: Optional[Union[firestore.RunAggregationQueryRequest, dict]] = None, diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index 862b098d1..50e0b6dd3 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -286,6 +286,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: gapic_v1.method.wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: gapic_v1.method.wrap_method( self.run_aggregation_query, default_retry=retries.Retry( @@ -509,6 +514,18 @@ def run_query( ]: raise NotImplementedError() + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], + Union[ + firestore.ExecutePipelineResponse, + Awaitable[firestore.ExecutePipelineResponse], + ], + ]: + raise NotImplementedError() + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index c302a73c2..2a8f4caf9 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -571,6 +571,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + ~.ExecutePipelineResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index f46162296..8801dc45a 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -587,6 +587,34 @@ def run_query( ) return self._stubs["run_query"] + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], Awaitable[firestore.ExecutePipelineResponse] + ]: + r"""Return a callable for the execute pipeline method over gRPC. + + Executes a pipeline query. + + Returns: + Callable[[~.ExecutePipelineRequest], + Awaitable[~.ExecutePipelineResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "execute_pipeline" not in self._stubs: + self._stubs["execute_pipeline"] = self._logged_channel.unary_stream( + "/google.firestore.v1.Firestore/ExecutePipeline", + request_serializer=firestore.ExecutePipelineRequest.serialize, + response_deserializer=firestore.ExecutePipelineResponse.deserialize, + ) + return self._stubs["execute_pipeline"] + @property def run_aggregation_query( self, @@ -962,6 +990,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=300.0, client_info=client_info, ), + self.execute_pipeline: self._wrap_method( + self.execute_pipeline, + default_timeout=None, + client_info=client_info, + ), self.run_aggregation_query: self._wrap_method( self.run_aggregation_query, default_retry=retries.AsyncRetry( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 3794ecea3..4bd282fe6 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -123,6 +123,14 @@ def pre_delete_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata + def pre_execute_pipeline(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_execute_pipeline(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_document(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -441,6 +449,56 @@ def pre_delete_document( """ return request, metadata + def pre_execute_pipeline( + self, + request: firestore.ExecutePipelineRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore.ExecutePipelineRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for execute_pipeline + + Override in a subclass to manipulate the request or metadata + before they are sent to the Firestore server. 
+ """ + return request, metadata + + def post_execute_pipeline( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for execute_pipeline + + DEPRECATED. Please use the `post_execute_pipeline_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the Firestore server but before + it is returned to user code. This `post_execute_pipeline` interceptor runs + before the `post_execute_pipeline_with_metadata` interceptor. + """ + return response + + def post_execute_pipeline_with_metadata( + self, + response: rest_streaming.ResponseIterator, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + rest_streaming.ResponseIterator, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for execute_pipeline + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the Firestore server but before it is returned to user code. + + We recommend only using this `post_execute_pipeline_with_metadata` + interceptor in new development instead of the `post_execute_pipeline` interceptor. + When both interceptors are used, this `post_execute_pipeline_with_metadata` interceptor runs after the + `post_execute_pipeline` interceptor. The (possibly modified) response returned by + `post_execute_pipeline` will be passed to + `post_execute_pipeline_with_metadata`. + """ + return response, metadata + def pre_get_document( self, request: firestore.GetDocumentRequest, @@ -932,35 +990,39 @@ def __init__( ) -> None: """Instantiate the transport. - Args: - host (Optional[str]): - The hostname to connect to (default: 'firestore.googleapis.com'). - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. + NOTE: This REST transport functionality is currently in a beta + state (preview). We welcome your feedback via a GitHub issue in + this library's repository. Thank you! + + Args: + host (Optional[str]): + The hostname to connect to (default: 'firestore.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. """ # Run the base constructor # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. @@ -1852,6 +1914,142 @@ def __call__( if response.status_code >= 400: raise core_exceptions.from_http_response(response) + class _ExecutePipeline( + _BaseFirestoreRestTransport._BaseExecutePipeline, FirestoreRestStub + ): + def __hash__(self): + return hash("FirestoreRestTransport.ExecutePipeline") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: firestore.ExecutePipelineRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the execute pipeline method over HTTP. + + Args: + request (~.firestore.ExecutePipelineRequest): + The request object. The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.firestore.ExecutePipelineResponse: + The response for [Firestore.Execute][]. 
+ """ + + http_options = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_http_options() + ) + + request, metadata = self._interceptor.pre_execute_pipeline( + request, metadata + ) + transcoded_request = _BaseFirestoreRestTransport._BaseExecutePipeline._get_transcoded_request( + http_options, request + ) + + body = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore_v1.FirestoreClient.ExecutePipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreRestTransport._ExecutePipeline._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, firestore.ExecutePipelineResponse + ) + + resp = self._interceptor.post_execute_pipeline(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_execute_pipeline_with_metadata( + resp, response_metadata + ) + return resp + class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): def __hash__(self): return hash("FirestoreRestTransport.GetDocument") @@ -3090,6 +3288,16 @@ def delete_document( # In C++ this would require a dynamic_cast return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + @property + def execute_pipeline( + self, + ) -> Callable[ + [firestore.ExecutePipelineRequest], firestore.ExecutePipelineResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ExecutePipeline(self._session, self._host, self._interceptor) # type: ignore + @property def get_document( self, diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 1d95cd16e..721f0792f 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -130,7 +130,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -139,7 +139,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -148,7 +148,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBatchWrite: @@ -187,7 +186,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -196,7 +195,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -205,7 +204,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBeginTransaction: @@ -244,7 +242,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -253,7 +251,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -262,7 +260,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCommit: @@ -301,7 +298,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -310,7 +307,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -319,7 +316,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCreateDocument: @@ -358,7 +354,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -367,7 +363,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( 
json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -376,7 +372,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseDeleteDocument: @@ -414,7 +409,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -423,7 +418,62 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=False + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=False, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + return query_params class _BaseGetDocument: @@ -461,7 +511,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -470,7 +520,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListCollectionIds: @@ -514,7 +563,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -523,7 +572,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -532,7 +581,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListDocuments: @@ -574,7 +622,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -583,7 +631,6 @@ def _get_query_params_json(transcoded_request): ) ) - 
query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListen: @@ -631,7 +678,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -640,7 +687,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -649,7 +696,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRollback: @@ -688,7 +734,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -697,7 +743,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -706,7 +752,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunAggregationQuery: @@ -750,7 +795,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -759,7 +804,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -768,7 +813,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunQuery: @@ -812,7 +856,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -821,7 +865,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -830,7 +874,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseUpdateDocument: @@ -869,7 +912,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True + transcoded_request["body"], use_integers_for_enums=False ) return body @@ -878,7 +921,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=True, + use_integers_for_enums=False, ) ) query_params.update( @@ -887,7 +930,6 @@ def _get_query_params_json(transcoded_request): ) ) - query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseWrite: diff --git a/google/cloud/firestore_v1/types/__init__.py b/google/cloud/firestore_v1/types/__init__.py index 
ae1004e13..ed1965d7f 100644 --- a/google/cloud/firestore_v1/types/__init__.py +++ b/google/cloud/firestore_v1/types/__init__.py @@ -28,9 +28,14 @@ from .document import ( ArrayValue, Document, + Function, MapValue, + Pipeline, Value, ) +from .explain_stats import ( + ExplainStats, +) from .firestore import ( BatchGetDocumentsRequest, BatchGetDocumentsResponse, @@ -42,6 +47,8 @@ CommitResponse, CreateDocumentRequest, DeleteDocumentRequest, + ExecutePipelineRequest, + ExecutePipelineResponse, GetDocumentRequest, ListCollectionIdsRequest, ListCollectionIdsResponse, @@ -62,6 +69,9 @@ WriteRequest, WriteResponse, ) +from .pipeline import ( + StructuredPipeline, +) from .query import ( Cursor, StructuredAggregationQuery, @@ -92,8 +102,11 @@ "TransactionOptions", "ArrayValue", "Document", + "Function", "MapValue", + "Pipeline", "Value", + "ExplainStats", "BatchGetDocumentsRequest", "BatchGetDocumentsResponse", "BatchWriteRequest", @@ -104,6 +117,8 @@ "CommitResponse", "CreateDocumentRequest", "DeleteDocumentRequest", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "GetDocumentRequest", "ListCollectionIdsRequest", "ListCollectionIdsResponse", @@ -123,6 +138,7 @@ "UpdateDocumentRequest", "WriteRequest", "WriteResponse", + "StructuredPipeline", "Cursor", "StructuredAggregationQuery", "StructuredQuery", diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 0942354f5..1757571b1 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -31,6 +31,8 @@ "Value", "ArrayValue", "MapValue", + "Function", + "Pipeline", }, ) @@ -183,6 +185,37 @@ class Value(proto.Message): map_value (google.cloud.firestore_v1.types.MapValue): A map value. + This field is a member of `oneof`_ ``value_type``. + field_reference_value (str): + Value which references a field. + + This is considered relative (vs absolute) since it only + refers to a field and not a field within a particular + document. + + **Requires:** + + - Must follow [field reference][FieldReference.field_path] + limitations. + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + function_value (google.cloud.firestore_v1.types.Function): + A value that represents an unevaluated expression. + + **Requires:** + + - Not allowed to be used when writing documents. + + This field is a member of `oneof`_ ``value_type``. + pipeline_value (google.cloud.firestore_v1.types.Pipeline): + A value that represents an unevaluated pipeline. + + **Requires:** + + - Not allowed to be used when writing documents. + This field is a member of `oneof`_ ``value_type``. """ @@ -246,6 +279,23 @@ class Value(proto.Message): oneof="value_type", message="MapValue", ) + field_reference_value: str = proto.Field( + proto.STRING, + number=19, + oneof="value_type", + ) + function_value: "Function" = proto.Field( + proto.MESSAGE, + number=20, + oneof="value_type", + message="Function", + ) + pipeline_value: "Pipeline" = proto.Field( + proto.MESSAGE, + number=21, + oneof="value_type", + message="Pipeline", + ) class ArrayValue(proto.Message): @@ -285,4 +335,119 @@ class MapValue(proto.Message): ) +class Function(proto.Message): + r"""Represents an unevaluated scalar expression. + + For example, the expression ``like(user_name, "%alice%")`` is + represented as: + + :: + + name: "like" + args { field_reference: "user_name" } + args { string_value: "%alice%" } + + Attributes: + name (str): + Required. 
The name of the function to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + function expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + +class Pipeline(proto.Message): + r"""A Firestore query represented as an ordered list of + operations / stages. + + Attributes: + stages (MutableSequence[google.cloud.firestore_v1.types.Pipeline.Stage]): + Required. Ordered list of stages to evaluate. + """ + + class Stage(proto.Message): + r"""A single operation within a pipeline. + + A stage is made up of a unique name, and a list of arguments. The + exact number of arguments & types is dependent on the stage type. + + To give an example, the stage ``filter(state = "MD")`` would be + encoded as: + + :: + + name: "filter" + args { + function_value { + name: "eq" + args { field_reference_value: "state" } + args { string_value: "MD" } + } + } + + See public documentation for the full list. + + Attributes: + name (str): + Required. The name of the stage to evaluate. + + **Requires:** + + - must be in snake case (lower case with underscore + separator). + args (MutableSequence[google.cloud.firestore_v1.types.Value]): + Optional. Ordered list of arguments the given + stage expects. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional named arguments that + certain functions may support. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: MutableSequence["Value"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Value", + ) + options: MutableMapping[str, "Value"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=3, + message="Value", + ) + + stages: MutableSequence[Stage] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=Stage, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py new file mode 100644 index 000000000..1fda228b6 --- /dev/null +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import any_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "ExplainStats", + }, +) + + +class ExplainStats(proto.Message): + r"""Explain stats for an RPC request, includes both the optimized + plan and execution stats. + + Attributes: + data (google.protobuf.any_pb2.Any): + The format depends on the ``output_format`` options in the + request. + + The only option today is ``TEXT``, which is a + ``google.protobuf.StringValue``. + """ + + data: any_pb2.Any = proto.Field( + proto.MESSAGE, + number=1, + message=any_pb2.Any, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index 53a6c6e7a..f1753c92f 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -22,6 +22,8 @@ from google.cloud.firestore_v1.types import aggregation_result from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats as gf_explain_stats +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query as gf_query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write @@ -48,6 +50,8 @@ "RollbackRequest", "RunQueryRequest", "RunQueryResponse", + "ExecutePipelineRequest", + "ExecutePipelineResponse", "RunAggregationQueryRequest", "RunAggregationQueryResponse", "PartitionQueryRequest", @@ -835,6 +839,147 @@ class RunQueryResponse(proto.Message): ) +class ExecutePipelineRequest(proto.Message): + r"""The request for + [Firestore.ExecutePipeline][google.firestore.v1.Firestore.ExecutePipeline]. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + database (str): + Required. Database identifier, in the form + ``projects/{project}/databases/{database}``. + structured_pipeline (google.cloud.firestore_v1.types.StructuredPipeline): + A pipelined operation. + + This field is a member of `oneof`_ ``pipeline_type``. + transaction (bytes): + Run the query within an already active + transaction. + The value here is the opaque transaction ID to + execute the query in. + + This field is a member of `oneof`_ ``consistency_selector``. + new_transaction (google.cloud.firestore_v1.types.TransactionOptions): + Execute the pipeline in a new transaction. + + The identifier of the newly created transaction + will be returned in the first response on the + stream. This defaults to a read-only + transaction. + + This field is a member of `oneof`_ ``consistency_selector``. + read_time (google.protobuf.timestamp_pb2.Timestamp): + Execute the pipeline in a snapshot + transaction at the given time. + This must be a microsecond precision timestamp + within the past one hour, or if Point-in-Time + Recovery is enabled, can additionally be a whole + minute timestamp within the past 7 days. + + This field is a member of `oneof`_ ``consistency_selector``. 
+ """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + structured_pipeline: pipeline.StructuredPipeline = proto.Field( + proto.MESSAGE, + number=2, + oneof="pipeline_type", + message=pipeline.StructuredPipeline, + ) + transaction: bytes = proto.Field( + proto.BYTES, + number=5, + oneof="consistency_selector", + ) + new_transaction: common.TransactionOptions = proto.Field( + proto.MESSAGE, + number=6, + oneof="consistency_selector", + message=common.TransactionOptions, + ) + read_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + oneof="consistency_selector", + message=timestamp_pb2.Timestamp, + ) + + +class ExecutePipelineResponse(proto.Message): + r"""The response for [Firestore.Execute][]. + + Attributes: + transaction (bytes): + Newly created transaction identifier. + + This field is only specified as part of the first response + from the server, alongside the ``results`` field when the + original request specified + [ExecuteRequest.new_transaction][]. + results (MutableSequence[google.cloud.firestore_v1.types.Document]): + An ordered batch of results returned executing a pipeline. + + The batch size is variable, and can even be zero for when + only a partial progress message is returned. + + The fields present in the returned documents are only those + that were explicitly requested in the pipeline, this include + those like [``__name__``][google.firestore.v1.Document.name] + & + [``__update_time__``][google.firestore.v1.Document.update_time]. + This is explicitly a divergence from ``Firestore.RunQuery`` + / ``Firestore.GetDocument`` RPCs which always return such + fields even when they are not specified in the + [``mask``][google.firestore.v1.DocumentMask]. + execution_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the document(s) were read. + + This may be monotonically increasing; in this case, the + previous documents in the result stream are guaranteed not + to have changed between their ``execution_time`` and this + one. + + If the query returns no results, a response with + ``execution_time`` and no ``results`` will be sent, and this + represents the time at which the operation was run. + explain_stats (google.cloud.firestore_v1.types.ExplainStats): + Query explain stats. + + Contains all metadata related to pipeline + planning and execution, specific contents depend + on the supplied pipeline options. + """ + + transaction: bytes = proto.Field( + proto.BYTES, + number=1, + ) + results: MutableSequence[gf_document.Document] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gf_document.Document, + ) + execution_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + explain_stats: gf_explain_stats.ExplainStats = proto.Field( + proto.MESSAGE, + number=4, + message=gf_explain_stats.ExplainStats, + ) + + class RunAggregationQueryRequest(proto.Message): r"""The request for [Firestore.RunAggregationQuery][google.firestore.v1.Firestore.RunAggregationQuery]. diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py new file mode 100644 index 000000000..29fbe884b --- /dev/null +++ b/google/cloud/firestore_v1/types/pipeline.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.cloud.firestore_v1.types import document + + +__protobuf__ = proto.module( + package="google.firestore.v1", + manifest={ + "StructuredPipeline", + }, +) + + +class StructuredPipeline(proto.Message): + r"""A Firestore query represented as an ordered list of operations / + stages. + + This is considered the top-level function which plans & executes a + query. It is logically equivalent to ``query(stages, options)``, but + prevents the client from having to build a function wrapper. + + Attributes: + pipeline (google.cloud.firestore_v1.types.Pipeline): + Required. The pipeline query to execute. + options (MutableMapping[str, google.cloud.firestore_v1.types.Value]): + Optional. Optional query-level arguments. + """ + + pipeline: document.Pipeline = proto.Field( + proto.MESSAGE, + number=1, + message=document.Pipeline, + ) + options: MutableMapping[str, document.Value] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=2, + message=document.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index eac609cab..d91e91c96 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -61,7 +61,9 @@ from google.cloud.firestore_v1.types import common from google.cloud.firestore_v1.types import document from google.cloud.firestore_v1.types import document as gf_document +from google.cloud.firestore_v1.types import explain_stats from google.cloud.firestore_v1.types import firestore +from google.cloud.firestore_v1.types import pipeline from google.cloud.firestore_v1.types import query from google.cloud.firestore_v1.types import query_profile from google.cloud.firestore_v1.types import write as gf_write @@ -3884,6 +3886,185 @@ async def test_run_query_field_headers_async(): ) in kw["metadata"] +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline(request_type, transport: str = "grpc"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = iter([firestore.ExecutePipelineResponse()]) + response = client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ for message in response: + assert isinstance(message, firestore.ExecutePipelineResponse) + + +def test_execute_pipeline_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore.ExecutePipelineRequest( + database="database_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.execute_pipeline(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore.ExecutePipelineRequest( + database="database_value", + ) + + +def test_execute_pipeline_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.execute_pipeline in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.execute_pipeline + ] = mock_rpc + request = {} + client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.execute_pipeline + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.execute_pipeline + ] = mock_rpc + + request = {} + await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + await client.execute_pipeline(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_execute_pipeline_async( + transport: str = "grpc_asyncio", request_type=firestore.ExecutePipelineRequest +): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + response = await client.execute_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore.ExecutePipelineRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, firestore.ExecutePipelineResponse) + + +@pytest.mark.asyncio +async def test_execute_pipeline_async_from_dict(): + await test_execute_pipeline_async(request_type=dict) + + @pytest.mark.parametrize( "request_type", [ @@ -6008,7 +6189,7 @@ def test_get_document_rest_required_fields(request_type=firestore.GetDocumentReq response = client.get_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6149,7 +6330,7 @@ def test_list_documents_rest_required_fields( response = client.list_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6350,7 +6531,7 @@ def test_update_document_rest_required_fields( response = client.update_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6542,7 +6723,7 @@ def test_delete_document_rest_required_fields( response = client.delete_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6728,7 +6909,7 @@ def test_batch_get_documents_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.batch_get_documents(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6851,7 +7032,7 @@ def test_begin_transaction_rest_required_fields( response = client.begin_transaction(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7028,7 +7209,7 @@ def test_commit_rest_required_fields(request_type=firestore.CommitRequest): response = client.commit(request) - expected_params = [("$alt", 
"json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7208,7 +7389,7 @@ def test_rollback_rest_required_fields(request_type=firestore.RollbackRequest): response = client.rollback(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7396,7 +7577,7 @@ def test_run_query_rest_required_fields(request_type=firestore.RunQueryRequest): iter_content.return_value = iter(json_return_value) response = client.run_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7410,7 +7591,7 @@ def test_run_query_rest_unset_required_fields(): assert set(unset_fields) == (set(()) & set(("parent",))) -def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): +def test_execute_pipeline_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7424,10 +7605,7 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert ( - client._transport.run_aggregation_query - in client._transport._wrapped_methods - ) + assert client._transport.execute_pipeline in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -7435,29 +7613,29 @@ def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.run_aggregation_query + client._transport.execute_pipeline ] = mock_rpc request = {} - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.run_aggregation_query(request) + client.execute_pipeline(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_run_aggregation_query_rest_required_fields( - request_type=firestore.RunAggregationQueryRequest, +def test_execute_pipeline_rest_required_fields( + request_type=firestore.ExecutePipelineRequest, ): transport_class = transports.FirestoreRestTransport request_init = {} - request_init["parent"] = "" + request_init["database"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -7468,21 +7646,21 @@ def test_run_aggregation_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["database"] = "database_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).run_aggregation_query._get_unset_required_fields(jsonified_request) + ).execute_pipeline._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" client = FirestoreClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7491,7 +7669,7 @@ def test_run_aggregation_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.RunAggregationQueryResponse() + return_value = firestore.ExecutePipelineResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7513,7 +7691,7 @@ def test_run_aggregation_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.RunAggregationQueryResponse.pb(return_value) + return_value = firestore.ExecutePipelineResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) json_return_value = "[{}]".format(json_return_value) @@ -7523,23 +7701,23 @@ def test_run_aggregation_query_rest_required_fields( with mock.patch.object(response_value, "iter_content") as iter_content: iter_content.return_value = iter(json_return_value) - response = client.run_aggregation_query(request) + response = client.execute_pipeline(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_run_aggregation_query_rest_unset_required_fields(): +def test_execute_pipeline_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("parent",))) + unset_fields = transport.execute_pipeline._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("database",))) -def test_partition_query_rest_use_cached_wrapped_rpc(): +def test_run_aggregation_query_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -7553,30 +7731,35 @@ def test_partition_query_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.partition_query in client._transport._wrapped_methods + assert ( + client._transport.run_aggregation_query + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + client._transport._wrapped_methods[ + client._transport.run_aggregation_query + ] = mock_rpc request = {} - client.partition_query(request) + client.run_aggregation_query(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.partition_query(request) + client.run_aggregation_query(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_partition_query_rest_required_fields( - request_type=firestore.PartitionQueryRequest, +def test_run_aggregation_query_rest_required_fields( + request_type=firestore.RunAggregationQueryRequest, ): transport_class = transports.FirestoreRestTransport @@ -7592,7 +7775,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7601,7 +7784,7 @@ def test_partition_query_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).partition_query._get_unset_required_fields(jsonified_request) + ).run_aggregation_query._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7615,7 +7798,7 @@ def test_partition_query_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = firestore.PartitionQueryResponse() + return_value = firestore.RunAggregationQueryResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7637,68 +7820,192 @@ def test_partition_query_rest_required_fields( response_value.status_code = 200 # Convert return value to protobuf type - return_value = firestore.PartitionQueryResponse.pb(return_value) + return_value = firestore.RunAggregationQueryResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.partition_query(request) + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.run_aggregation_query(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_partition_query_rest_unset_required_fields(): +def test_run_aggregation_query_rest_unset_required_fields(): transport = transports.FirestoreRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.partition_query._get_unset_required_fields({}) + unset_fields = transport.run_aggregation_query._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("parent",))) -def test_partition_query_rest_pager(transport: str = "rest"): - client = FirestoreClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
- # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - query.Cursor(), - query.Cursor(), - ], - next_page_token="abc", - ), - firestore.PartitionQueryResponse( - partitions=[], - next_page_token="def", - ), - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - ], - next_page_token="ghi", - ), - firestore.PartitionQueryResponse( - partitions=[ - query.Cursor(), - query.Cursor(), - ], - ), +def test_partition_query_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - # Two responses for two calls - response = response + response + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.partition_query in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.partition_query] = mock_rpc + + request = {} + client.partition_query(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.partition_query(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_partition_query_rest_required_fields( + request_type=firestore.PartitionQueryRequest, +): + transport_class = transports.FirestoreRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).partition_query._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = firestore.PartitionQueryResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.PartitionQueryResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.partition_query(request) + + expected_params = [] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_partition_query_rest_unset_required_fields(): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.partition_query._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("parent",))) + + +def test_partition_query_rest_pager(transport: str = "rest"): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + query.Cursor(), + query.Cursor(), + ], + next_page_token="abc", + ), + firestore.PartitionQueryResponse( + partitions=[], + next_page_token="def", + ), + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + ], + next_page_token="ghi", + ), + firestore.PartitionQueryResponse( + partitions=[ + query.Cursor(), + query.Cursor(), + ], + ), + ) + # Two responses for two calls + response = response + response # Wrap the values into proper Response objs response = tuple(firestore.PartitionQueryResponse.to_json(x) for x in response) @@ -7854,7 +8161,7 @@ def test_list_collection_ids_rest_required_fields( response = client.list_collection_ids(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8094,7 +8401,7 @@ def test_batch_write_rest_required_fields(request_type=firestore.BatchWriteReque response = client.batch_write(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8226,7 +8533,7 @@ def test_create_document_rest_required_fields( response = client.create_document(request) - expected_params = [("$alt", "json;enum-encoding=int")] + expected_params = [] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8553,6 +8860,27 @@ def test_run_query_empty_call_grpc(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. 
+def test_execute_pipeline_empty_call_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_grpc(): @@ -8662,6 +8990,60 @@ def test_create_document_empty_call_grpc(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_grpc(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + call.return_value = iter([firestore.ExecutePipelineResponse()]) + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_grpc_asyncio(): transport = FirestoreAsyncClient.get_transport_class("grpc_asyncio")( credentials=async_anonymous_credentials() @@ -8911,6 +9293,32 @@ async def test_run_query_empty_call_grpc_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_execute_pipeline_empty_call_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. @pytest.mark.asyncio @@ -9048,6 +9456,70 @@ async def test_create_document_empty_call_grpc_asyncio(): assert args[0] == request_msg +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_1_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +@pytest.mark.asyncio +async def test_execute_pipeline_routing_parameters_request_2_grpc_asyncio(): + client = FirestoreAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[firestore.ExecutePipelineResponse()] + ) + await client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_rest(): transport = FirestoreClient.get_transport_class("rest")( credentials=ga_credentials.AnonymousCredentials() @@ -10233,6 +10705,137 @@ def test_run_query_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() +def test_execute_pipeline_rest_bad_request( + request_type=firestore.ExecutePipelineRequest, +): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.execute_pipeline(request) + + +@pytest.mark.parametrize( + "request_type", + [ + firestore.ExecutePipelineRequest, + dict, + ], +) +def test_execute_pipeline_rest_call_success(request_type): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/databases/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = firestore.ExecutePipelineResponse( + transaction=b"transaction_blob", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = firestore.ExecutePipelineResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + response_value.iter_content = mock.Mock(return_value=iter(json_return_value)) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.execute_pipeline(request) + + assert isinstance(response, Iterable) + response = next(response) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, firestore.ExecutePipelineResponse) + assert response.transaction == b"transaction_blob" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_execute_pipeline_rest_interceptors(null_interceptor): + transport = transports.FirestoreRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.FirestoreRestInterceptor(), + ) + client = FirestoreClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline" + ) as post, mock.patch.object( + transports.FirestoreRestInterceptor, "post_execute_pipeline_with_metadata" + ) as post_with_metadata, mock.patch.object( + transports.FirestoreRestInterceptor, "pre_execute_pipeline" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = firestore.ExecutePipelineRequest.pb( + firestore.ExecutePipelineRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = firestore.ExecutePipelineResponse.to_json( + firestore.ExecutePipelineResponse() + ) + req.return_value.iter_content = mock.Mock(return_value=iter(return_value)) + + request = firestore.ExecutePipelineRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = firestore.ExecutePipelineResponse() + post_with_metadata.return_value = firestore.ExecutePipelineResponse(), metadata + + client.execute_pipeline( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + def test_run_aggregation_query_rest_bad_request( request_type=firestore.RunAggregationQueryRequest, ): @@ -11409,6 +12012,26 @@ def test_run_query_empty_call_rest(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_execute_pipeline_empty_call_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest() + + assert args[0] == request_msg + + # This test is a coverage failsafe to make sure that totally empty calls, # i.e. request == None and no flattened fields passed, work. def test_run_aggregation_query_empty_call_rest(): @@ -11513,6 +12136,58 @@ def test_create_document_empty_call_rest(): assert args[0] == request_msg +def test_execute_pipeline_routing_parameters_request_1_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. 
+ with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline(request={"database": "projects/sample1/sample2"}) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/sample2"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_execute_pipeline_routing_parameters_request_2_rest(): + client = FirestoreClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.execute_pipeline), "__call__") as call: + client.execute_pipeline( + request={"database": "projects/sample1/databases/sample2/sample3"} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore.ExecutePipelineRequest( + **{"database": "projects/sample1/databases/sample2/sample3"} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = FirestoreClient( @@ -11555,6 +12230,7 @@ def test_firestore_base_transport(): "commit", "rollback", "run_query", + "execute_pipeline", "run_aggregation_query", "partition_query", "write", @@ -11860,6 +12536,9 @@ def test_firestore_client_transport_session_collision(transport_name): session1 = client1.transport.run_query._session session2 = client2.transport.run_query._session assert session1 != session2 + session1 = client1.transport.execute_pipeline._session + session2 = client2.transport.execute_pipeline._session + assert session1 != session2 session1 = client1.transport.run_aggregation_query._session session2 = client2.transport.run_aggregation_query._session assert session1 != session2 From 17e71b9e92f7ac293a49e7990cce5eb17cff897a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 17 Jun 2025 15:54:31 -0700 Subject: [PATCH 02/27] feat: add pipelines structure (#1046) --- google/cloud/firestore_v1/_helpers.py | 3 + google/cloud/firestore_v1/_pipeline_stages.py | 81 ++++ google/cloud/firestore_v1/async_client.py | 9 + google/cloud/firestore_v1/async_pipeline.py | 96 +++++ google/cloud/firestore_v1/base_client.py | 17 + google/cloud/firestore_v1/base_pipeline.py | 151 +++++++ google/cloud/firestore_v1/client.py | 9 + google/cloud/firestore_v1/field_path.py | 4 +- google/cloud/firestore_v1/pipeline.py | 90 ++++ .../firestore_v1/pipeline_expressions.py | 85 ++++ google/cloud/firestore_v1/pipeline_result.py | 139 +++++++ google/cloud/firestore_v1/pipeline_source.py | 53 +++ noxfile.py | 1 + tests/unit/v1/test_async_client.py | 11 + tests/unit/v1/test_async_pipeline.py | 393 ++++++++++++++++++ tests/unit/v1/test_client.py | 12 + tests/unit/v1/test_pipeline.py | 370 +++++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 104 +++++ tests/unit/v1/test_pipeline_result.py | 176 ++++++++ tests/unit/v1/test_pipeline_source.py | 56 +++ tests/unit/v1/test_pipeline_stages.py | 121 ++++++ 21 files changed, 1979 insertions(+), 2 deletions(-) create mode 100644 google/cloud/firestore_v1/_pipeline_stages.py create 
mode 100644 google/cloud/firestore_v1/async_pipeline.py create mode 100644 google/cloud/firestore_v1/base_pipeline.py create mode 100644 google/cloud/firestore_v1/pipeline.py create mode 100644 google/cloud/firestore_v1/pipeline_expressions.py create mode 100644 google/cloud/firestore_v1/pipeline_result.py create mode 100644 google/cloud/firestore_v1/pipeline_source.py create mode 100644 tests/unit/v1/test_async_pipeline.py create mode 100644 tests/unit/v1/test_pipeline.py create mode 100644 tests/unit/v1/test_pipeline_expressions.py create mode 100644 tests/unit/v1/test_pipeline_result.py create mode 100644 tests/unit/v1/test_pipeline_source.py create mode 100644 tests/unit/v1/test_pipeline_stages.py diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 399bdb066..1fbc1a476 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -120,6 +120,9 @@ def __ne__(self, other): else: return not equality_val + def __repr__(self): + return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})" + def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py new file mode 100644 index 000000000..3871a363d --- /dev/null +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Optional +from abc import ABC +from abc import abstractmethod + +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.pipeline_expressions import Expr + + +class Stage(ABC): + """Base class for all pipeline stages. + + Each stage represents a specific operation (e.g., filtering, sorting, + transforming) within a Firestore pipeline. Subclasses define the specific + arguments and behavior for each operation. 
+    """
+
+    def __init__(self, custom_name: Optional[str] = None):
+        self.name = custom_name or type(self).__name__.lower()
+
+    def _to_pb(self) -> Pipeline_pb.Stage:
+        return Pipeline_pb.Stage(
+            name=self.name, args=self._pb_args(), options=self._pb_options()
+        )
+
+    @abstractmethod
+    def _pb_args(self) -> list[Value]:
+        """Return an ordered list of arguments the given stage expects."""
+        raise NotImplementedError
+
+    def _pb_options(self) -> dict[str, Value]:
+        """Return optional named arguments that certain stages may support."""
+        return {}
+
+    def __repr__(self):
+        items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name")
+        return f"{self.__class__.__name__}({', '.join(items)})"
+
+
+class Collection(Stage):
+    """Specifies a collection as the initial data source."""
+
+    def __init__(self, path: str):
+        super().__init__()
+        if not path.startswith("/"):
+            path = f"/{path}"
+        self.path = path
+
+    def _pb_args(self):
+        return [Value(reference_value=self.path)]
+
+
+class GenericStage(Stage):
+    """Represents a generic, named stage with parameters."""
+
+    def __init__(self, name: str, *params: Expr | Value):
+        super().__init__(name)
+        self.params: list[Value] = [
+            p._to_pb() if isinstance(p, Expr) else p for p in params
+        ]
+
+    def _pb_args(self):
+        return self.params
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(name='{self.name}')"
diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py
index 15b31af31..3acbedc76 100644
--- a/google/cloud/firestore_v1/async_client.py
+++ b/google/cloud/firestore_v1/async_client.py
@@ -47,6 +47,8 @@
 from google.cloud.firestore_v1.services.firestore.transports import (
     grpc_asyncio as firestore_grpc_transport,
 )
+from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+from google.cloud.firestore_v1.pipeline_source import PipelineSource
 
 if TYPE_CHECKING:  # pragma: NO COVER
     import datetime
@@ -427,3 +429,10 @@ def transaction(self, **kwargs) -> AsyncTransaction:
             A transaction attached to this client.
         """
         return AsyncTransaction(self, **kwargs)
+
+    @property
+    def _pipeline_cls(self):
+        return AsyncPipeline
+
+    def pipeline(self) -> PipelineSource:
+        return PipelineSource(self)
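The `Stage` contract above is deliberately small: a stage carries a wire `name` plus the ordered `args` returned by `_pb_args()`. A minimal sketch of how a subclass plugs into `_to_pb()`, using a hypothetical `Limit` stage that is not part of this patch:

    # Illustrative sketch only; `Limit` is a hypothetical stage, not part of this patch.
    from google.cloud.firestore_v1._pipeline_stages import Stage
    from google.cloud.firestore_v1.types.document import Value

    class Limit(Stage):
        """Hypothetical stage restricting the number of results."""

        def __init__(self, limit: int):
            super().__init__()  # wire name defaults to the lowercased class name: "limit"
            self.limit = limit

        def _pb_args(self) -> list[Value]:
            return [Value(integer_value=self.limit)]

    stage = Limit(10)
    print(stage)                # Limit(limit=10), via Stage.__repr__
    print(stage._to_pb().name)  # "limit"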
diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py
new file mode 100644
index 000000000..471c33093
--- /dev/null
+++ b/google/cloud/firestore_v1/async_pipeline.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import AsyncIterable, TYPE_CHECKING
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+
+
+class AsyncPipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    This class extends `_BasePipeline` and provides methods to execute the
+    defined pipeline stages using an asynchronous `AsyncClient`.
+
+    Usage Example:
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+        >>>
+        >>> async def run_pipeline():
+        ...     client = AsyncClient(...)
+        ...     pipeline = (
+        ...         client.pipeline()
+        ...         .collection("books")
+        ...         .where(Field.of("published").gt(1980))
+        ...         .select("title", "author")
+        ...     )
+        ...     async for result in pipeline.stream():
+        ...         print(result)
+
+    Use `client.pipeline()` to create instances of this class.
+    """
+
+    def __init__(self, client: AsyncClient, *stages: stages.Stage):
+        """
+        Initializes an asynchronous Pipeline.
+
+        Args:
+            client: The asynchronous `AsyncClient` instance to use for execution.
+            *stages: Initial stages for the pipeline.
+        """
+        super().__init__(client, *stages)
+
+    async def execute(
+        self,
+        transaction: "AsyncTransaction" | None = None,
+    ) -> list[PipelineResult]:
+        """
+        Executes this pipeline and returns the results as a list.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        return [result async for result in self.stream(transaction=transaction)]
+
+    async def stream(
+        self,
+        transaction: "AsyncTransaction" | None = None,
+    ) -> AsyncIterable[PipelineResult]:
+        """
+        Process this pipeline as a stream, providing results through an async iterable.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        request = self._prep_execute_request(transaction)
+        async for response in await self._client._firestore_api.execute_pipeline(
+            request
+        ):
+            for result in self._execute_response_helper(response):
+                yield result
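For a reviewer trying the new surface, a minimal sketch of the two consumption paths, assuming an `AsyncClient` with ambient credentials and the `books` collection used throughout the docstrings: `execute()` buffers the response stream into a list, while `stream()` yields results as `ExecutePipelineResponse` messages arrive.

    # Illustrative sketch; assumes ambient credentials and an existing "books" collection.
    import asyncio
    from google.cloud.firestore_v1.async_client import AsyncClient

    async def main():
        client = AsyncClient()
        pipeline = client.pipeline().collection("books")

        results = await pipeline.execute()  # fully buffered list[PipelineResult]
        print(len(results))

        async for result in pipeline.stream():  # lazy, yields as responses arrive
            print(result.data())

    asyncio.run(main())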
diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py
index 4a0e3f6b8..8c8b9532d 100644
--- a/google/cloud/firestore_v1/base_client.py
+++ b/google/cloud/firestore_v1/base_client.py
@@ -37,6 +37,7 @@
     Optional,
     Tuple,
     Union,
+    Type,
 )
 
 import google.api_core.client_options
@@ -61,6 +62,8 @@
 from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
 from google.cloud.firestore_v1.field_path import render_field_path
 from google.cloud.firestore_v1.services.firestore import client as firestore_client
+from google.cloud.firestore_v1.pipeline_source import PipelineSource
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
 
 DEFAULT_DATABASE = "(default)"
 """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""
@@ -500,6 +503,20 @@ def batch(self) -> BaseWriteBatch:
     def transaction(self, **kwargs) -> BaseTransaction:
         raise NotImplementedError
 
+    def pipeline(self) -> PipelineSource:
+        """
+        Start a pipeline with this client.
+
+        Returns:
+            :class:`~google.cloud.firestore_v1.pipeline_source.PipelineSource`:
+            A pipeline source that uses this client.
+        """
+        raise NotImplementedError
+
+    @property
+    def _pipeline_cls(self) -> Type["_BasePipeline"]:
+        raise NotImplementedError
+
 
 def _reference_info(references: list) -> Tuple[list, dict]:
     """Get information about document references.
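The abstract `pipeline()` / `_pipeline_cls` pair is a small factory hook: shared machinery such as `PipelineSource` can build the right pipeline flavor without importing the sync or async client directly. A self-contained sketch of the same pattern, with made-up names standing in for the real classes:

    # Illustrative sketch of the `_pipeline_cls` hook; class names here are made up.
    from typing import Type

    class PipelineBase:
        def __init__(self, client):
            self._client = client

    class SyncPipeline(PipelineBase):
        pass

    class ClientBase:
        @property
        def _pipeline_cls(self) -> Type[PipelineBase]:
            raise NotImplementedError

    class SyncClient(ClientBase):
        @property
        def _pipeline_cls(self) -> Type[PipelineBase]:
            return SyncPipeline

    client = SyncClient()
    print(type(client._pipeline_cls(client)).__name__)  # SyncPipeline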
diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
new file mode 100644
index 000000000..dde906fe6
--- /dev/null
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -0,0 +1,151 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Iterable, Sequence, TYPE_CHECKING
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.types.pipeline import (
+    StructuredPipeline as StructuredPipeline_pb,
+)
+from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
+from google.cloud.firestore_v1.pipeline_result import PipelineResult
+from google.cloud.firestore_v1.pipeline_expressions import Expr
+from google.cloud.firestore_v1 import _helpers
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse
+    from google.cloud.firestore_v1.transaction import BaseTransaction
+
+
+class _BasePipeline:
+    """
+    Base class for building Firestore data transformation and query pipelines.
+
+    This class is not intended to be instantiated directly.
+    Use `client.pipeline()` to create pipeline instances.
+    """
+
+    def __init__(self, client: Client | AsyncClient, *stages: stages.Stage):
+        """
+        Initializes a new pipeline.
+
+        Pipelines should not be instantiated directly. Instead,
+        call `client.pipeline()` to create an instance.
+
+        Args:
+            client: The client associated with the pipeline.
+            *stages: Initial stages for the pipeline.
+        """
+        self._client = client
+        self.stages: Sequence[stages.Stage] = tuple(stages)
+
+    @classmethod
+    def _create_with_stages(
+        cls, client: Client | AsyncClient, *stages
+    ) -> _BasePipeline:
+        """
+        Initializes a new pipeline with the given stages.
+
+        Pipeline classes should not be instantiated directly.
+
+        Args:
+            client: The client associated with the pipeline.
+            *stages: Initial stages for the pipeline.
+        """
+        new_instance = cls(client)
+        new_instance.stages = tuple(stages)
+        return new_instance
+
+    def __repr__(self):
+        cls_str = type(self).__name__
+        if not self.stages:
+            return f"{cls_str}()"
+        elif len(self.stages) == 1:
+            return f"{cls_str}({self.stages[0]!r})"
+        else:
+            stages_str = ",\n  ".join([repr(s) for s in self.stages])
+            return f"{cls_str}(\n  {stages_str}\n)"
+
+    def _to_pb(self) -> StructuredPipeline_pb:
+        return StructuredPipeline_pb(
+            pipeline={"stages": [s._to_pb() for s in self.stages]}
+        )
+
+    def _append(self, new_stage):
+        """
+        Create a new Pipeline object with a new stage appended.
+        """
+        return self.__class__._create_with_stages(self._client, *self.stages, new_stage)
+
+    def _prep_execute_request(
+        self, transaction: BaseTransaction | None
+    ) -> ExecutePipelineRequest:
+        """
+        Shared logic for creating an ExecutePipelineRequest.
+        """
+        database_name = (
+            f"projects/{self._client.project}/databases/{self._client._database}"
+        )
+        transaction_id = (
+            _helpers.get_transaction_id(transaction)
+            if transaction is not None
+            else None
+        )
+        request = ExecutePipelineRequest(
+            database=database_name,
+            transaction=transaction_id,
+            structured_pipeline=self._to_pb(),
+        )
+        return request
+
+    def _execute_response_helper(
+        self, response: ExecutePipelineResponse
+    ) -> Iterable[PipelineResult]:
+        """
+        Shared logic for unpacking an ExecutePipelineResponse into PipelineResults.
+        """
+        for doc in response.results:
+            ref = self._client.document(doc.name) if doc.name else None
+            yield PipelineResult(
+                self._client,
+                doc.fields,
+                ref,
+                response._pb.execution_time,
+                doc._pb.create_time if doc.create_time else None,
+                doc._pb.update_time if doc.update_time else None,
+            )
+
+    def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline":
+        """
+        Adds a generic, named stage to the pipeline with specified parameters.
+
+        This method provides a flexible way to extend the pipeline's functionality
+        by adding custom stages. Each generic stage is defined by a unique `name`
+        and a set of `params` that control its behavior.
+
+        Example:
+            >>> # Assume we don't have a built-in "where" stage
+            >>> pipeline = client.pipeline().collection("books")
+            >>> pipeline = pipeline.generic_stage("where", Field.of("published").lt(900))
+            >>> pipeline = pipeline.select("title", "author")
+
+        Args:
+            name: The name of the generic stage.
+            *params: A sequence of `Expr` objects representing the parameters for the stage.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list.
+        """
+        return self._append(stages.GenericStage(name, *params))
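`_BasePipeline` ties the earlier pieces together: `_append` never mutates the receiver, and `_to_pb()` emits the `StructuredPipeline` message introduced in the first commit of this series. A hedged sketch, assuming an existing `client` and treating `"limit"` as a hypothetical stage name:

    # Illustrative sketch: immutable chaining and proto serialization.
    from google.cloud.firestore_v1 import _pipeline_stages as stages
    from google.cloud.firestore_v1.pipeline import Pipeline
    from google.cloud.firestore_v1.pipeline_expressions import Constant

    base = Pipeline._create_with_stages(client, stages.Collection("books"))
    limited = base.generic_stage("limit", Constant.of(10))  # hypothetical stage name

    print(base.stages)             # (Collection(path='/books'),) -- unchanged by _append
    structured = limited._to_pb()  # StructuredPipeline proto
    for stage_pb in structured.pipeline.stages:
        print(stage_pb.name)       # "collection", then "limit"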
diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py
index ec906f991..c23943b24 100644
--- a/google/cloud/firestore_v1/client.py
+++ b/google/cloud/firestore_v1/client.py
@@ -49,6 +49,8 @@
     grpc as firestore_grpc_transport,
 )
 from google.cloud.firestore_v1.transaction import Transaction
+from google.cloud.firestore_v1.pipeline import Pipeline
+from google.cloud.firestore_v1.pipeline_source import PipelineSource
 
 if TYPE_CHECKING:  # pragma: NO COVER
     from google.cloud.firestore_v1.bulk_writer import BulkWriter
@@ -408,3 +410,10 @@ def transaction(self, **kwargs) -> Transaction:
             A transaction attached to this client.
         """
         return Transaction(self, **kwargs)
+
+    @property
+    def _pipeline_cls(self):
+        return Pipeline
+
+    def pipeline(self) -> PipelineSource:
+        return PipelineSource(self)
diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py
index 27ac6cc45..32516d3be 100644
--- a/google/cloud/firestore_v1/field_path.py
+++ b/google/cloud/firestore_v1/field_path.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 import re
 from collections import abc
-from typing import Iterable, cast
+from typing import Any, Iterable, cast, MutableMapping
 
 _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data"
 _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}"
@@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]):
 get_field_path = render_field_path  # backward-compatibility
 
 
-def get_nested_value(field_path: str, data: dict):
+def get_nested_value(field_path: str, data: MutableMapping[str, Any]):
     """Get a (potentially nested) value from a dictionary.
 
     If the data is nested, for example:
diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py
new file mode 100644
index 000000000..9f568f925
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline.py
@@ -0,0 +1,90 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Iterable, TYPE_CHECKING
+from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.transaction import Transaction
+
+
+class Pipeline(_BasePipeline):
+    """
+    Pipelines allow for complex data transformations and queries involving
+    multiple stages like filtering, projection, aggregation, and vector search.
+
+    Usage Example:
+        >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+        >>>
+        >>> def run_pipeline():
+        ...     client = Client(...)
+        ...     pipeline = (
+        ...         client.pipeline()
+        ...         .collection("books")
+        ...         .where(Field.of("published").gt(1980))
+        ...         .select("title", "author")
+        ...     )
+        ...     for result in pipeline.execute():
+        ...         print(result)
+
+    Use `client.pipeline()` to create instances of this class.
+    """
+
+    def __init__(self, client: Client, *stages: stages.Stage):
+        """
+        Initializes a Pipeline.
+
+        Args:
+            client: The `Client` instance to use for execution.
+            *stages: Initial stages for the pipeline.
+        """
+        super().__init__(client, *stages)
+
+    def execute(
+        self,
+        transaction: "Transaction" | None = None,
+    ) -> list[PipelineResult]:
+        """
+        Executes this pipeline and returns the results as a list.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        return [result for result in self.stream(transaction=transaction)]
+
+    def stream(
+        self,
+        transaction: "Transaction" | None = None,
+    ) -> Iterable[PipelineResult]:
+        """
+        Process this pipeline as a stream, providing results through an iterable.
+
+        Args:
+            transaction
+                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+                An existing transaction that this query will run in.
+                If a ``transaction`` is used and it already has write operations
+                added, this method cannot be used (i.e. read-after-write is not
+                allowed).
+        """
+        request = self._prep_execute_request(transaction)
+        for response in self._client._firestore_api.execute_pipeline(request):
+            yield from self._execute_response_helper(response)
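Both `Pipeline.stream()` and its async twin funnel through the shared `_prep_execute_request`, so the wire request can be inspected without issuing an RPC. Continuing the sketch from the previous example:

    # Illustrative sketch: the ExecutePipelineRequest built before any RPC is sent.
    request = limited._prep_execute_request(transaction=None)
    print(request.database)             # "projects/<project>/databases/(default)"
    print(request.structured_pipeline)  # the StructuredPipeline proto from _to_pb()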
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
new file mode 100644
index 000000000..5e0c775a2
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -0,0 +1,85 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import (
+    Any,
+    Generic,
+    TypeVar,
+    Dict,
+)
+from abc import ABC
+from abc import abstractmethod
+import datetime
+from google.cloud.firestore_v1.types.document import Value
+from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1._helpers import GeoPoint
+from google.cloud.firestore_v1._helpers import encode_value
+
+CONSTANT_TYPE = TypeVar(
+    "CONSTANT_TYPE",
+    str,
+    int,
+    float,
+    bool,
+    datetime.datetime,
+    bytes,
+    GeoPoint,
+    Vector,
+    list,
+    Dict[str, Any],
+    None,
+)
+
+
+class Expr(ABC):
+    """Represents an expression that can be evaluated to a value within the
+    execution of a pipeline.
+
+    Expressions are the building blocks for creating complex queries and
+    transformations in Firestore pipelines. They can represent:
+
+    - **Field references:** Access values from document fields.
+    - **Literals:** Represent constant values (strings, numbers, booleans).
+    - **Function calls:** Apply functions to one or more expressions.
+    """
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}()"
+
+    @abstractmethod
+    def _to_pb(self) -> Value:
+        raise NotImplementedError
+
+
+class Constant(Expr, Generic[CONSTANT_TYPE]):
+    """Represents a constant literal value in an expression."""
+
+    def __init__(self, value: CONSTANT_TYPE):
+        self.value: CONSTANT_TYPE = value
+
+    @staticmethod
+    def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]:
+        """Creates a constant expression from a Python value."""
+        return Constant(value)
+
+    def __repr__(self):
+        return f"Constant.of({self.value!r})"
+
+    def _to_pb(self) -> Value:
+        return encode_value(self.value)
diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py
new file mode 100644
index 000000000..ada855fea
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_result.py
@@ -0,0 +1,139 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+from typing import Any, MutableMapping, TYPE_CHECKING
+from google.cloud.firestore_v1 import _helpers
+from google.cloud.firestore_v1.field_path import get_nested_value
+from google.cloud.firestore_v1.field_path import FieldPath
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.base_client import BaseClient
+    from google.cloud.firestore_v1.base_document import BaseDocumentReference
+    from google.protobuf.timestamp_pb2 import Timestamp
+    from google.cloud.firestore_v1.types.document import Value as ValueProto
+    from google.cloud.firestore_v1.vector import Vector
+
+
+class PipelineResult:
+    """
+    Contains data read from a Firestore Pipeline. The data can be extracted with
+    the `data()` or `get()` methods.
+
+    If the PipelineResult represents a non-document result, `ref` may be `None`.
+    """
+
+    def __init__(
+        self,
+        client: BaseClient,
+        fields_pb: MutableMapping[str, ValueProto],
+        ref: BaseDocumentReference | None = None,
+        execution_time: Timestamp | None = None,
+        create_time: Timestamp | None = None,
+        update_time: Timestamp | None = None,
+    ):
+        """
+        PipelineResult should be returned from `pipeline.execute()`, not constructed manually.
+
+        Args:
+            client: The Firestore client instance.
+            fields_pb: A map of field names to their protobuf Value representations.
+            ref: The DocumentReference or AsyncDocumentReference if this result corresponds to a document.
+            execution_time: The time at which the pipeline execution producing this result occurred.
+            create_time: The creation time of the document, if applicable.
+            update_time: The last update time of the document, if applicable.
+        """
+        self._client = client
+        self._fields_pb = fields_pb
+        self._ref = ref
+        self._execution_time = execution_time
+        self._create_time = create_time
+        self._update_time = update_time
+
+    def __repr__(self):
+        return f"{type(self).__name__}(data={self.data()})"
+
+    @property
+    def ref(self) -> BaseDocumentReference | None:
+        """
+        The `BaseDocumentReference` if this result represents a document, else `None`.
+        """
+        return self._ref
+
+    @property
+    def id(self) -> str | None:
+        """The ID of the document if this result represents a document, else `None`."""
+        return self._ref.id if self._ref else None
+
+    @property
+    def create_time(self) -> Timestamp | None:
+        """The creation time of the document. `None` if not applicable."""
+        return self._create_time
+
+    @property
+    def update_time(self) -> Timestamp | None:
+        """The last update time of the document. `None` if not applicable."""
+        return self._update_time
+
+    @property
+    def execution_time(self) -> Timestamp:
+        """
+        The time at which the pipeline producing this result was executed.
+
+        Raises:
+            ValueError: if not set
+        """
+        if self._execution_time is None:
+            raise ValueError("'execution_time' is expected to exist, but it is None.")
+        return self._execution_time
+
+    def __eq__(self, other: object) -> bool:
+        """
+        Compares this `PipelineResult` to another object for equality.
+
+        Two `PipelineResult` instances are considered equal if their document
+        references (if any) are equal and their underlying field data
+        (protobuf representation) is identical.
+        """
+        if not isinstance(other, PipelineResult):
+            return NotImplemented
+        return (self._ref == other._ref) and (self._fields_pb == other._fields_pb)
+
+    def data(self) -> dict | "Vector" | None:
+        """
+        Retrieves all fields in the result.
+
+        Returns:
+            The data in dictionary format, or `None` if the document doesn't exist.
+        """
+        if self._fields_pb is None:
+            return None
+
+        return _helpers.decode_dict(self._fields_pb, self._client)
+
+    def get(self, field_path: str | FieldPath) -> Any:
+        """
+        Retrieves the field specified by `field_path`.
+
+        Args:
+            field_path: The field path (e.g. 'foo' or 'foo.bar') to a specific field.
+
+        Returns:
+            The data at the specified field location, decoded to Python types.
+        """
+        str_path = (
+            field_path if isinstance(field_path, str) else field_path.to_api_repr()
+        )
+        value = get_nested_value(str_path, self._fields_pb)
+        return _helpers.decode_value(value, self._client)
diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py
new file mode 100644
index 000000000..f2f081fee
--- /dev/null
+++ b/google/cloud/firestore_v1/pipeline_source.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
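+
+# Usage sketch (illustrative; assumes a "users" collection exists):
+#
+#     pipeline = client.pipeline().collection("users")
+#     for result in pipeline.execute():
+#         print(result.data())
+#
+# A sync `Client` produces `Pipeline` objects from this factory, while an
+# `AsyncClient` produces `AsyncPipeline` objects (see each client's
+# `_pipeline_cls` property).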
+ +from __future__ import annotations +from typing import Generic, TypeVar, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + + +PipelineType = TypeVar("PipelineType", bound=_BasePipeline) + + +class PipelineSource(Generic[PipelineType]): + """ + A factory for creating Pipeline instances, which provide a framework for building data + transformation and query pipelines for Firestore. + + Not meant to be instantiated directly. Instead, start by calling client.pipeline() + to obtain an instance of PipelineSource. From there, you can use the provided + methods to specify the data source for your pipeline. + """ + + def __init__(self, client: Client | AsyncClient): + self.client = client + + def _create_pipeline(self, source_stage): + return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + + def collection(self, path: str) -> PipelineType: + """ + Creates a new Pipeline that operates on a specified Firestore collection. + + Args: + path: The path to the Firestore collection (e.g., "users") + Returns: + a new pipeline instance targeting the specified collection + """ + return self._create_pipeline(stages.Collection(path)) diff --git a/noxfile.py b/noxfile.py index 9e81d7179..a01af1bad 100644 --- a/noxfile.py +++ b/noxfile.py @@ -70,6 +70,7 @@ SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", "six", + "pyyaml", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 4924856a8..210aae88d 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -560,6 +560,17 @@ def test_asyncclient_transaction(): assert transaction._id is None +def test_asyncclient_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_async_client() + ppl = client.pipeline() + assert client._pipeline_cls == AsyncPipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py new file mode 100644 index 000000000..3abc3619b --- /dev/null +++ b/tests/unit/v1/test_async_pipeline.py @@ -0,0 +1,393 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +def _make_async_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + return AsyncPipeline._create_with_stages(client, *args) + + +async def _async_it(list): + for value in list: + yield value + + +def test_ctor(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + instance = AsyncPipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + stages = [object() for i in range(10)] + instance = AsyncPipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_async_pipeline_repr_empty(): + ppl = _make_async_pipeline() + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline()" + + +def test_async_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_async_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline(SingleStage)" + + +def test_async_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.GenericStage("second", 2) + stage_3 = stages.GenericStage("third", 3) + ppl = _make_async_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "AsyncPipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" + ")" + ) + + +def test_async_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_async_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_async_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.GenericStage("first") + stage_2 = stages.GenericStage("second") + ppl = _make_async_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_async_pipeline_append(): + """append should create a new pipeline with the additional stage""" + stage_1 = stages.GenericStage("first") + ppl_1 = _make_async_pipeline(stage_1, client=object()) + stage_2 = stages.GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = 
_async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it( + [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] + ) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + 
assert response.data() == {"key": "str_val"}
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_stream_multiple():
+    """
+    test stream pipeline with multiple docs and responses
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+
+    mock_rpc.return_value = _async_it(
+        [
+            ExecutePipelineResponse(
+                results=[
+                    Document(fields={"key": Value(integer_value=0)}),
+                    Document(fields={"key": Value(integer_value=1)}),
+                ],
+                execution_time={"seconds": 0},
+            ),
+            ExecutePipelineResponse(
+                results=[
+                    Document(fields={"key": Value(integer_value=2)}),
+                    Document(fields={"key": Value(integer_value=3)}),
+                ],
+                execution_time={"seconds": 1},
+            ),
+        ]
+    )
+    ppl_1 = _make_async_pipeline(client=client)
+
+    results = [r async for r in ppl_1.stream()]
+    assert len(results) == 4
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+
+    for idx, response in enumerate(results):
+        assert isinstance(response, PipelineResult)
+        assert response.data() == {"key": idx}
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_stream_with_transaction():
+    """
+    test stream pipeline with transaction context
+    """
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+
+    transaction = AsyncTransaction(client)
+    transaction._id = b"123"
+
+    mock_rpc.return_value = _async_it([ExecutePipelineResponse()])
+    ppl_1 = _make_async_pipeline(client=client)
+
+    [r async for r in ppl_1.stream(transaction=transaction)]
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b"123"
+
+
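+# NOTE: `execute()` is expected to be equivalent to collecting `stream()` into
+# a list, i.e. roughly:
+#
+#     results = [r async for r in pipeline.stream()]
+#
+# The two tests below pin down that contract.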
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = mock.AsyncMock()
+    client._firestore_api.execute_pipeline = mock_rpc
+    mock_response = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    mock_rpc.return_value = _async_it(mock_response)
+    ppl_1 = _make_async_pipeline(client=client)
+
+    stream_results = [r async for r in ppl_1.stream()]
+    # reset response
+    mock_rpc.return_value = _async_it(mock_response)
+    execute_results = await ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+@pytest.mark.asyncio
+async def test_async_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_async_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with mock.patch.object(ppl_1, "stream") as mock_stream:
+        mock_stream.return_value = _async_it(expected_data)
+        stream_results = await ppl_1.execute(expected_arg)
+    assert mock_stream.call_count == 1
+    assert mock_stream.call_args[0] == ()
+    assert len(mock_stream.call_args[1]) == 1
+    assert mock_stream.call_args[1]["transaction"] == expected_arg
+    assert stream_results == expected_data
+
+
+@pytest.mark.parametrize(
+    "method,args,result_cls",
+    [
+        ("generic_stage", ("name",), stages.GenericStage),
+        ("generic_stage", ("name", mock.Mock()), stages.GenericStage),
+    ],
+)
+def test_async_pipeline_methods(method, args, result_cls):
+    start_ppl = _make_async_pipeline()
+    method_ptr = getattr(start_ppl, method)
+    result_ppl = method_ptr(*args)
+    assert result_ppl != start_ppl
+    assert len(start_ppl.stages) == 0
+    assert len(result_ppl.stages) == 1
+    assert isinstance(result_ppl.stages[0], result_cls)
diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py
index df3ae15b4..9d0199f92 100644
--- a/tests/unit/v1/test_client.py
+++ b/tests/unit/v1/test_client.py
@@ -648,6 +648,18 @@ def test_client_transaction(database):
     assert transaction._id is None
 
 
+@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"])
+def test_client_pipeline(database):
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1.pipeline_source import PipelineSource
+
+    client = _make_default_client(database=database)
+    ppl = client.pipeline()
+    assert client._pipeline_cls == Pipeline
+    assert isinstance(ppl, PipelineSource)
+    assert ppl.client == client
+
+
 def _make_batch_response(**kwargs):
     from google.cloud.firestore_v1.types import firestore
diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py
new file mode 100644
index 000000000..6a3fef3ac
--- /dev/null
+++ b/tests/unit/v1/test_pipeline.py
@@ -0,0 +1,370 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +def _make_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.pipeline import Pipeline + + return Pipeline._create_with_stages(client, *args) + + +def test_ctor(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + instance = Pipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + stages = [object() for i in range(10)] + instance = Pipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_pipeline_repr_empty(): + ppl = _make_pipeline() + repr_str = repr(ppl) + assert repr_str == "Pipeline()" + + +def test_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "Pipeline(SingleStage)" + + +def test_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.GenericStage("second", 2) + stage_3 = stages.GenericStage("third", 3) + ppl = _make_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "Pipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" + ")" + ) + + +def test_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.GenericStage("first") + stage_2 = stages.GenericStage("second") + ppl = _make_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_pipeline_append(): + """append should create a new pipeline with the additional stage""" + + stage_1 = stages.GenericStage("first") + ppl_1 = _make_pipeline(stage_1, client=object()) + stage_2 = stages.GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +def test_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert 
request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +def test_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ + ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) + ] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +def test_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +def test_pipeline_stream_multiple(): + """ + test stream pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result 
import PipelineResult
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=0)}),
+                Document(fields={"key": Value(integer_value=1)}),
+            ],
+            execution_time={"seconds": 0},
+        ),
+        ExecutePipelineResponse(
+            results=[
+                Document(fields={"key": Value(integer_value=2)}),
+                Document(fields={"key": Value(integer_value=3)}),
+            ],
+            execution_time={"seconds": 1},
+        ),
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    results = list(ppl_1.stream())
+    assert len(results) == 4
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+
+    for idx, response in enumerate(results):
+        assert isinstance(response, PipelineResult)
+        assert response.data() == {"key": idx}
+
+
+def test_pipeline_stream_with_transaction():
+    """
+    test stream pipeline with transaction context
+    """
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import ExecutePipelineRequest
+    from google.cloud.firestore_v1.transaction import Transaction
+
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    transaction = Transaction(client)
+    transaction._id = b"123"
+
+    mock_rpc.return_value = [ExecutePipelineResponse()]
+    ppl_1 = _make_pipeline(client=client)
+
+    list(ppl_1.stream(transaction=transaction))
+    assert mock_rpc.call_count == 1
+    request = mock_rpc.call_args[0][0]
+    assert isinstance(request, ExecutePipelineRequest)
+    assert request.structured_pipeline == ppl_1._to_pb()
+    assert request.database == "projects/A/databases/B"
+    assert request.transaction == b"123"
+
+
+def test_pipeline_execute_stream_equivalence():
+    """
+    Pipeline.execute should provide the same results as pipeline.stream, as a list
+    """
+    from google.cloud.firestore_v1.types import Document
+    from google.cloud.firestore_v1.types import ExecutePipelineResponse
+    from google.cloud.firestore_v1.types import Value
+    from google.cloud.firestore_v1.client import Client
+
+    real_client = Client()
+    client = mock.Mock()
+    client.project = "A"
+    client._database = "B"
+    client.document = real_client.document
+    mock_rpc = client._firestore_api.execute_pipeline
+
+    mock_rpc.return_value = [
+        ExecutePipelineResponse(
+            results=[
+                Document(
+                    name="test/my_doc",
+                    fields={"key": Value(string_value="str_val")},
+                )
+            ],
+        )
+    ]
+    ppl_1 = _make_pipeline(client=client)
+
+    stream_results = list(ppl_1.stream())
+    execute_results = ppl_1.execute()
+    assert stream_results == execute_results
+    assert stream_results[0].data()["key"] == "str_val"
+    assert execute_results[0].data()["key"] == "str_val"
+
+
+def test_pipeline_execute_stream_equivalence_mocked():
+    """
+    pipeline.execute should call pipeline.stream internally
+    """
+    ppl_1 = _make_pipeline()
+    expected_data = [object(), object()]
+    expected_arg = object()
+    with mock.patch.object(ppl_1, "stream") as mock_stream:
+        mock_stream.return_value = expected_data
+        stream_results = ppl_1.execute(expected_arg)
+    assert mock_stream.call_count == 1
+    assert mock_stream.call_args[0] == ()
+    assert len(mock_stream.call_args[1]) == 1
+    assert mock_stream.call_args[1]["transaction"] 
== expected_arg + assert stream_results == expected_data + + +@pytest.mark.parametrize( + "method,args,result_cls", + [ + ("generic_stage", ("name",), stages.GenericStage), + ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ], +) +def test_pipeline_methods(method, args, result_cls): + start_ppl = _make_pipeline() + method_ptr = getattr(start_ppl, method) + result_ppl = method_ptr(*args) + assert result_ppl != start_ppl + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], result_cls) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py new file mode 100644 index 000000000..19ebed3b5 --- /dev/null +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +import datetime + +import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestExpr: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expressions.Expr() + + +class TestConstant: + @pytest.mark.parametrize( + "input_val, to_pb_val", + [ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + (0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + [0.0, 1.0, 2.0], + Value( + array_value={"values": [Value(double_value=i) for i in range(3)]} + ), + ), + ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), + ( + Vector([1.0, 2.0]), + Value( + map_value={ + "fields": { + "__type__": Value(string_value="__vector__"), + "value": Value( + array_value={ + "values": [Value(double_value=v) for v in [1, 2]], + } + ), + } + } + ), + ), + ], + ) + def test_to_pb(self, input_val, to_pb_val): + instance = expressions.Constant.of(input_val) + assert instance._to_pb() == to_pb_val + + @pytest.mark.parametrize( + "input_val,expected", + [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + ( + datetime.datetime(2025, 5, 12), + "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))", + ), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), 
"Constant.of(Vector<1.0, 2.0>)"), + ], + ) + def test_repr(self, input_val, expected): + instance = expressions.Constant.of(input_val) + repr_string = repr(instance) + assert repr_string == expected diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py new file mode 100644 index 000000000..2facf7110 --- /dev/null +++ b/tests/unit/v1/test_pipeline_result.py @@ -0,0 +1,176 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1.pipeline_result import PipelineResult + + +class TestPipelineResult: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [mock.Mock(), {}] + return PipelineResult(*args, **kwargs) + + def test_ref(self): + expected = object() + instance = self._make_one(ref=expected) + assert instance.ref == expected + # should be None if not set + assert self._make_one().ref is None + + def test_id(self): + ref = mock.Mock() + ref.id = "test" + instance = self._make_one(ref=ref) + assert instance.id == "test" + # should be None if not set + assert self._make_one().id is None + + def test_create_time(self): + expected = object() + instance = self._make_one(create_time=expected) + assert instance.create_time == expected + # should be None if not set + assert self._make_one().create_time is None + + def test_update_time(self): + expected = object() + instance = self._make_one(update_time=expected) + assert instance.update_time == expected + # should be None if not set + assert self._make_one().update_time is None + + def test_exection_time(self): + expected = object() + instance = self._make_one(execution_time=expected) + assert instance.execution_time == expected + # should raise if not set + with pytest.raises(ValueError) as e: + self._make_one().execution_time + assert "execution_time" in e + + @pytest.mark.parametrize( + "first,second,result", + [ + ((object(), {}), (object(), {}), True), + ((object(), {1: 1}), (object(), {1: 1}), True), + ((object(), {1: 1}), (object(), {2: 2}), False), + ((object(), {}, "ref"), (object(), {}, "ref"), True), + ((object(), {}, "ref"), (object(), {}, "diff"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "ref"), True), + ((object(), {1: 1}, "ref"), (object(), {2: 2}, "ref"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "diff"), False), + ( + (object(), {1: 1}, "ref", 1, 2, 3), + (object(), {1: 1}, "ref", 4, 5, 6), + True, + ), + ], + ) + def test_eq(self, first, second, result): + first_obj = self._make_one(*first) + second_obj = self._make_one(*second) + assert (first_obj == second_obj) is result + + def test_eq_wrong_type(self): + instance = self._make_one() + result = instance == object() + assert result is False + + def test_data(self): + from google.cloud.firestore_v1.types.document import Value + + client = mock.Mock() + data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)} + instance = self._make_one(client, data) 
+ got = instance.data() + assert len(got) == 2 + assert got["str"] == "hello world" + assert got["int"] == 5 + + def test_data_none(self): + client = object() + data = None + instance = self._make_one(client, data) + assert instance.data() is None + + def test_data_call(self): + """ + ensure decode_dict is called on .data + """ + client = object() + data = {"hello": "world"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_dict" + ) as decode_mock: + got = instance.data() + decode_mock.assert_called_once_with(data, client) + assert got == decode_mock.return_value + + def test_get(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"key": Value(string_value="hello world")} + instance = self._make_one(client, data) + got = instance.get("key") + assert got == "hello world" + + def test_get_nested(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + instance = self._make_one(client, data) + got = instance.get("first.second") + assert got == "hello world" + + def test_get_field_path(self): + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.field_path import FieldPath + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + path = FieldPath.from_string("first.second") + instance = self._make_one(client, data) + got = instance.get(path) + assert got == "hello world" + + def test_get_failure(self): + """ + test calling get on value not in data + """ + client = object() + data = {} + instance = self._make_one(client, data) + with pytest.raises(KeyError): + instance.get("key") + + def test_get_call(self): + """ + ensure decode_value is called on .get() + """ + client = object() + data = {"key": "value"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_value" + ) as decode_mock: + got = instance.get("key") + decode_mock.assert_called_once_with("value", client) + assert got == decode_mock.return_value diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py new file mode 100644 index 000000000..cd8b56b68 --- /dev/null +++ b/tests/unit/v1/test_pipeline_source.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.async_client import AsyncClient +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +class TestPipelineSource: + _expected_pipeline_type = Pipeline + + def _make_client(self): + return Client() + + def test_make_from_client(self): + instance = self._make_client().pipeline() + assert isinstance(instance, PipelineSource) + + def test_create_pipeline(self): + instance = self._make_client().pipeline() + ppl = instance._create_pipeline(None) + assert isinstance(ppl, self._expected_pipeline_type) + + def test_collection(self): + instance = self._make_client().pipeline() + ppl = instance.collection("path") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/path" + + +class TestPipelineSourceWithAsyncClient(TestPipelineSource): + """ + When an async client is used, it should produce async pipelines + """ + + _expected_pipeline_type = AsyncPipeline + + def _make_client(self): + return AsyncClient() diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py new file mode 100644 index 000000000..59d808d63 --- /dev/null +++ b/tests/unit/v1/test_pipeline_stages.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import pytest + +import google.cloud.firestore_v1._pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestStage: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + stages.Stage() + + +class TestCollection: + def _make_one(self, *args, **kwargs): + return stages.Collection(*args, **kwargs) + + @pytest.mark.parametrize( + "input_arg,expected", + [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ], + ) + def test_repr(self, input_arg, expected): + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + input_arg = "test/col" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection" + assert len(result.args) == 1 + assert result.args[0].reference_value == "/test/col" + assert len(result.options) == 0 + + +class TestGenericStage: + def _make_one(self, *args, **kwargs): + return stages.GenericStage(*args, **kwargs) + + @pytest.mark.parametrize( + "input_args,expected_params", + [ + (("name",), []), + (("custom", Value(string_value="val")), [Value(string_value="val")]), + (("n", Value(integer_value=1)), [Value(integer_value=1)]), + (("n", Constant.of(1)), [Value(integer_value=1)]), + ( + ("n", Constant.of(True), Constant.of(False)), + [Value(boolean_value=True), Value(boolean_value=False)], + ), + ( + ("n", Constant.of(GeoPoint(1, 2))), + [Value(geo_point_value={"latitude": 1, "longitude": 2})], + ), + (("n", Constant.of(None)), [Value(null_value=0)]), + ( + ("n", Constant.of([0, 1, 2])), + [ + Value( + array_value={ + "values": [Value(integer_value=n) for n in range(3)] + } + ) + ], + ), + ( + ("n", Value(reference_value="/projects/p/databases/d/documents/doc")), + [Value(reference_value="/projects/p/databases/d/documents/doc")], + ), + ( + ("n", Constant.of({"a": "b"})), + [Value(map_value={"fields": {"a": Value(string_value="b")}})], + ), + ], + ) + def test_ctor(self, input_args, expected_params): + instance = self._make_one(*input_args) + assert instance.params == expected_params + + @pytest.mark.parametrize( + "input_args,expected", + [ + (("name",), "GenericStage(name='name')"), + (("custom", Value(string_value="val")), "GenericStage(name='custom')"), + ], + ) + def test_repr(self, input_args, expected): + instance = self._make_one(*input_args) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + instance = self._make_one("name", Constant.of(True), Constant.of("test")) + result = instance._to_pb() + assert result.name == "name" + assert len(result.args) == 2 + assert result.args[0].boolean_value is True + assert result.args[1].string_value == "test" + assert len(result.options) == 0 From 6e0633600fad4c83f779e92bb66fab99a3a4cb57 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 16 Jul 2025 10:57:01 -0700 Subject: [PATCH 03/27] feat: add primary pipeline stages (#1048) --- google/cloud/firestore_v1/_pipeline_stages.py | 392 ++- google/cloud/firestore_v1/base_pipeline.py | 464 +++- .../firestore_v1/pipeline_expressions.py | 2217 +++++++++++++++++ google/cloud/firestore_v1/pipeline_source.py | 38 +- tests/system/pipeline_e2e.yaml | 1640 ++++++++++++ tests/system/test__helpers.py | 6 + 
tests/system/test_pipeline_acceptance.py | 285 +++ tests/system/test_system.py | 284 ++- tests/system/test_system_async.py | 230 +- tests/unit/v1/test_async_pipeline.py | 42 +- tests/unit/v1/test_pipeline.py | 42 +- tests/unit/v1/test_pipeline_expressions.py | 1137 ++++++++- tests/unit/v1/test_pipeline_source.py | 44 + tests/unit/v1/test_pipeline_stages.py | 690 ++++- 14 files changed, 7242 insertions(+), 269 deletions(-) create mode 100644 tests/system/pipeline_e2e.yaml create mode 100644 tests/system/test_pipeline_acceptance.py diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index 3871a363d..f7d311d89 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -13,13 +13,109 @@ # limitations under the License. from __future__ import annotations -from typing import Optional +from typing import Optional, Sequence, TYPE_CHECKING from abc import ABC from abc import abstractmethod +from enum import Enum from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.types.document import Value -from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure +from google.cloud.firestore_v1.pipeline_expressions import ( + Accumulator, + Expr, + ExprWithAlias, + Field, + FilterCondition, + Selectable, + Ordering, +) +from google.cloud.firestore_v1._helpers import encode_value + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_pipeline import _BasePipeline + from google.cloud.firestore_v1.base_document import BaseDocumentReference + + +class FindNearestOptions: + """Options for configuring the `FindNearest` pipeline stage. + + Attributes: + limit (Optional[int]): The maximum number of nearest neighbors to return. + distance_field (Optional[Field]): An optional field to store the calculated + distance in the output documents. + """ + + def __init__( + self, + limit: Optional[int] = None, + distance_field: Optional[Field] = None, + ): + self.limit = limit + self.distance_field = distance_field + + def __repr__(self): + args = [] + if self.limit is not None: + args.append(f"limit={self.limit}") + if self.distance_field is not None: + args.append(f"distance_field={self.distance_field}") + return f"{self.__class__.__name__}({', '.join(args)})" + + +class SampleOptions: + """Options for the 'sample' pipeline stage.""" + + class Mode(Enum): + DOCUMENTS = "documents" + PERCENT = "percent" + + def __init__(self, value: int | float, mode: Mode | str): + self.value = value + self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode + + def __repr__(self): + if self.mode == SampleOptions.Mode.DOCUMENTS: + mode_str = "doc_limit" + else: + mode_str = "percentage" + return f"SampleOptions.{mode_str}({self.value})" + + @staticmethod + def doc_limit(value: int): + """ + Sample a set number of documents + + Args: + value: number of documents to sample + """ + return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS) + + @staticmethod + def percentage(value: float): + """ + Sample a percentage of documents + + Args: + value: percentage of documents to return + """ + return SampleOptions(value, mode=SampleOptions.Mode.PERCENT) + + +class UnnestOptions: + """Options for configuring the `Unnest` pipeline stage. 
+ + Attributes: + index_field (str): The name of the field to add to each output document, + storing the original 0-based index of the element within the array. + """ + + def __init__(self, index_field: str): + self.index_field = index_field + + def __repr__(self): + return f"{self.__class__.__name__}(index_field={self.index_field!r})" class Stage(ABC): @@ -52,6 +148,68 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join(items)})" +class AddFields(Stage): + """Adds new fields to outputs from previous stages.""" + + def __init__(self, *fields: Selectable): + super().__init__("add_fields") + self.fields = list(fields) + + def _pb_args(self): + return [ + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} + } + ) + ] + + +class Aggregate(Stage): + """Performs aggregation operations, optionally grouped.""" + + def __init__( + self, + *args: ExprWithAlias[Accumulator], + accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + groups: Sequence[str | Selectable] = (), + ): + super().__init__() + self.groups: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in groups + ] + if args and accumulators: + raise ValueError( + "Aggregate stage contains both positional and keyword accumulators" + ) + self.accumulators = args or accumulators + + def _pb_args(self): + return [ + Value( + map_value={ + "fields": { + m[0]: m[1] for m in [f._to_map() for f in self.accumulators] + } + } + ), + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]} + } + ), + ] + + def __repr__(self): + accumulator_str = ", ".join(repr(v) for v in self.accumulators) + group_str = "" + if self.groups: + if self.accumulators: + group_str = ", " + group_str += f"groups={self.groups}" + return f"{self.__class__.__name__}({accumulator_str}{group_str})" + + class Collection(Stage): """Specifies a collection as the initial data source.""" @@ -65,6 +223,103 @@ def _pb_args(self): return [Value(reference_value=self.path)] +class CollectionGroup(Stage): + """Specifies a collection group as the initial data source.""" + + def __init__(self, collection_id: str): + super().__init__("collection_group") + self.collection_id = collection_id + + def _pb_args(self): + return [Value(string_value=self.collection_id)] + + +class Database(Stage): + """Specifies the default database as the initial data source.""" + + def __init__(self): + super().__init__() + + def _pb_args(self): + return [] + + +class Distinct(Stage): + """Returns documents with distinct combinations of specified field values.""" + + def __init__(self, *fields: str | Selectable): + super().__init__() + self.fields: list[Selectable] = [ + Field(f) if isinstance(f, str) else f for f in fields + ] + + def _pb_args(self) -> list[Value]: + return [ + Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} + } + ) + ] + + +class Documents(Stage): + """Specifies specific documents as the initial data source.""" + + def __init__(self, *paths: str): + super().__init__() + self.paths = paths + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.paths])})" + + @staticmethod + def of(*documents: "BaseDocumentReference") -> "Documents": + doc_paths = ["/" + doc.path for doc in documents] + return Documents(*doc_paths) + + def _pb_args(self): + return [ + Value( + array_value={ + "values": [Value(string_value=path) for path in self.paths] + } + ) + ] + + +class FindNearest(Stage): + """Performs vector distance 
(similarity) search.""" + + def __init__( + self, + field: str | Expr, + vector: Sequence[float] | Vector, + distance_measure: "DistanceMeasure", + options: Optional["FindNearestOptions"] = None, + ): + super().__init__("find_nearest") + self.field: Expr = Field(field) if isinstance(field, str) else field + self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) + self.distance_measure = distance_measure + self.options = options or FindNearestOptions() + + def _pb_args(self): + return [ + self.field._to_pb(), + encode_value(self.vector), + Value(string_value=self.distance_measure.name.lower()), + ] + + def _pb_options(self) -> dict[str, Value]: + options = {} + if self.options and self.options.limit is not None: + options["limit"] = Value(integer_value=self.options.limit) + if self.options and self.options.distance_field is not None: + options["distance_field"] = self.options.distance_field._to_pb() + return options + + class GenericStage(Stage): """Represents a generic, named stage with parameters.""" @@ -79,3 +334,136 @@ def _pb_args(self): def __repr__(self): return f"{self.__class__.__name__}(name='{self.name}')" + + +class Limit(Stage): + """Limits the maximum number of documents returned.""" + + def __init__(self, limit: int): + super().__init__() + self.limit = limit + + def _pb_args(self): + return [Value(integer_value=self.limit)] + + +class Offset(Stage): + """Skips a specified number of documents.""" + + def __init__(self, offset: int): + super().__init__() + self.offset = offset + + def _pb_args(self): + return [Value(integer_value=self.offset)] + + +class RemoveFields(Stage): + """Removes specified fields from outputs.""" + + def __init__(self, *fields: str | Field): + super().__init__("remove_fields") + self.fields = [Field(f) if isinstance(f, str) else f for f in fields] + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join(repr(f) for f in self.fields)})" + + def _pb_args(self) -> list[Value]: + return [f._to_pb() for f in self.fields] + + +class Sample(Stage): + """Performs pseudo-random sampling of documents.""" + + def __init__(self, limit_or_options: int | SampleOptions): + super().__init__() + if isinstance(limit_or_options, int): + options = SampleOptions.doc_limit(limit_or_options) + else: + options = limit_or_options + self.options: SampleOptions = options + + def _pb_args(self): + if self.options.mode == SampleOptions.Mode.DOCUMENTS: + return [ + Value(integer_value=self.options.value), + Value(string_value="documents"), + ] + else: + return [ + Value(double_value=self.options.value), + Value(string_value="percent"), + ] + + +class Select(Stage): + """Selects or creates a set of fields.""" + + def __init__(self, *selections: str | Selectable): + super().__init__() + self.projections = [Field(s) if isinstance(s, str) else s for s in selections] + + def _pb_args(self) -> list[Value]: + return [Selectable._value_from_selectables(*self.projections)] + + +class Sort(Stage): + """Sorts documents based on specified criteria.""" + + def __init__(self, *orders: "Ordering"): + super().__init__() + self.orders = list(orders) + + def _pb_args(self): + return [o._to_pb() for o in self.orders] + + +class Union(Stage): + """Performs a union of documents from two pipelines.""" + + def __init__(self, other: _BasePipeline): + super().__init__() + self.other = other + + def _pb_args(self): + return [Value(pipeline_value=self.other._to_pb().pipeline)] + + +class Unnest(Stage): + """Produces a document for each element in an array field.""" + + 
def __init__( + self, + field: Selectable | str, + alias: Field | str | None = None, + options: UnnestOptions | None = None, + ): + super().__init__() + self.field: Selectable = Field(field) if isinstance(field, str) else field + if alias is None: + self.alias = self.field + elif isinstance(alias, str): + self.alias = Field(alias) + else: + self.alias = alias + self.options = options + + def _pb_args(self): + return [self.field._to_pb(), self.alias._to_pb()] + + def _pb_options(self): + options = {} + if self.options is not None: + options["index_field"] = Value(string_value=self.options.index_field) + return options + + +class Where(Stage): + """Filters documents based on a specified condition.""" + + def __init__(self, condition: FilterCondition): + super().__init__() + self.condition = condition + + def _pb_args(self): + return [self.condition._to_pb()] diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index dde906fe6..50ae7ab62 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -18,9 +18,18 @@ from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult -from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1.pipeline_expressions import ( + Accumulator, + Expr, + ExprWithAlias, + Field, + FilterCondition, + Selectable, +) from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: # pragma: NO COVER @@ -35,7 +44,7 @@ class _BasePipeline: Base class for building Firestore data transformation and query pipelines. This class is not intended to be instantiated directly. - Use `client.collection.("...").pipeline()` to create pipeline instances. + Use `client.pipeline()` to create pipeline instances. """ def __init__(self, client: Client | AsyncClient): @@ -127,6 +136,328 @@ def _execute_response_helper( doc._pb.update_time if doc.update_time else None, ) + def add_fields(self, *fields: Selectable) -> "_BasePipeline": + """ + Adds new fields to outputs from previous stages. + + This stage allows you to compute values on-the-fly based on existing data + from previous stages or constants. You can use this to create new fields + or overwrite existing ones (if there is name overlap). + + The added fields are defined using `Selectable` expressions, which can be: + - `Field`: References an existing document field. + - `Function`: Performs a calculation using functions like `add`, + `multiply` with assigned aliases using `Expr.as_()`. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.add_fields( + ... Field.of("rating").as_("bookRating"), # Rename 'rating' to 'bookRating' + ... add(5, Field.of("quantity")).as_("totalCost") # Calculate 'totalCost' + ... ) + + Args: + *fields: The fields to add to the documents, specified as `Selectable` + expressions. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.AddFields(*fields)) + + def remove_fields(self, *fields: Field | str) -> "_BasePipeline": + """ + Removes fields from outputs of previous stages. 
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Remove by name
+            >>> pipeline = pipeline.remove_fields("rating", "cost")
+            >>> # Remove by Field object
+            >>> pipeline = pipeline.remove_fields(Field.of("rating"), Field.of("cost"))
+
+        Args:
+            *fields: The fields to remove, specified as field names (str) or
+                `Field` objects.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.RemoveFields(*fields))
+
+    def select(self, *selections: str | Selectable) -> "_BasePipeline":
+        """
+        Selects or creates a set of fields from the outputs of previous stages.
+
+        The selected fields are defined using `Selectable` expressions or field names:
+        - `Field`: References an existing document field.
+        - `Function`: Represents the result of a function with an assigned alias
+          name using `Expr.as_()`.
+        - `str`: The name of an existing field.
+
+        If no selections are provided, the output of this stage is empty. Use
+        `add_fields()` instead if only additions are desired.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Select by name
+            >>> pipeline = pipeline.select("name", "address")
+            >>> # Select using Field and Function expressions
+            >>> pipeline = pipeline.select(
+            ...     Field.of("name"),
+            ...     Field.of("address").to_upper().as_("upperAddress"),
+            ... )
+
+        Args:
+            *selections: The fields to include in the output documents, specified as
+                field names (str) or `Selectable` expressions.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Select(*selections))
+
+    def where(self, condition: FilterCondition) -> "_BasePipeline":
+        """
+        Filters the documents from previous stages to only include those matching
+        the specified `FilterCondition`.
+
+        This stage allows you to apply conditions to the data, similar to a "WHERE"
+        clause in SQL. You can filter documents based on their field values, using
+        implementations of `FilterCondition`, typically including but not limited to:
+        - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc.
+        - logical operators: `And`, `Or`, `Not`, etc.
+        - advanced functions: `regex_matches`, `array_contains`, etc.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Combine conditions with the And operator
+            >>> pipeline = pipeline.where(
+            ...     And(
+            ...         Field.of("rating").gt(4.0),  # Filter for ratings > 4.0
+            ...         Field.of("genre").eq("Science Fiction")  # Filter for genre
+            ...     )
+            ... )
+            >>> # A single comparator condition needs no composite operator
+            >>> pipeline = pipeline.where(Field.of("rating").gt(4.0))
+
+        Args:
+            condition: The `FilterCondition` to apply.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Where(condition))
+
+    def find_nearest(
+        self,
+        field: str | Expr,
+        vector: Sequence[float] | "Vector",
+        distance_measure: "DistanceMeasure",
+        options: stages.FindNearestOptions | None = None,
+    ) -> "_BasePipeline":
+        """
+        Performs vector distance (similarity) search with given parameters on the
+        stage inputs.
+
+        This stage adds a "nearest neighbor search" capability to your pipelines.
+        Given a field or expression that evaluates to a vector and a target vector,
+        this stage will identify and return the inputs whose vector is closest to
+        the target vector, using the specified distance measure and options.
+
+        Example:
+            >>> from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
+            >>> from google.cloud.firestore_v1.pipeline_stages import FindNearestOptions
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>>
+            >>> target_vector = [0.1, 0.2, 0.3]
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Find using field name
+            >>> pipeline = pipeline.find_nearest(
+            ...     "topicVectors",
+            ...     target_vector,
+            ...     DistanceMeasure.COSINE,
+            ...     options=FindNearestOptions(limit=10, distance_field="distance")
+            ... )
+            >>> # Find using Field expression
+            >>> pipeline = pipeline.find_nearest(
+            ...     Field.of("topicVectors"),
+            ...     target_vector,
+            ...     DistanceMeasure.COSINE,
+            ...     options=FindNearestOptions(limit=10, distance_field="distance")
+            ... )
+
+        Args:
+            field: The name of the field (str) or an expression (`Expr`) that
+                evaluates to the vector data. This field should store vector values.
+            vector: The target vector (sequence of floats or `Vector` object) to
+                compare against.
+            distance_measure: The distance measure (`DistanceMeasure`) to use
+                (e.g., `DistanceMeasure.COSINE`, `DistanceMeasure.EUCLIDEAN`).
+            options: Configuration options (`FindNearestOptions`) for the search,
+                such as the limit and the output distance field name.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(
+            stages.FindNearest(field, vector, distance_measure, options)
+        )
+
+    def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
+        """
+        Sorts the documents from previous stages based on one or more `Ordering` criteria.
+
+        This stage allows you to order the results of your pipeline. You can specify
+        multiple `Ordering` instances to sort by multiple fields or expressions in
+        ascending or descending order. If documents have the same value for a sorting
+        criterion, the next specified ordering will be used. If all orderings result
+        in equal comparison, the documents are considered equal and the relative order
+        is unspecified.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Sort books by rating descending, then title ascending
+            >>> pipeline = pipeline.sort(
+            ...     Field.of("rating").descending(),
+            ...     Field.of("title").ascending()
+            ... )
+
+        Args:
+            *orders: One or more `Ordering` instances specifying the sorting criteria.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Sort(*orders))
+
+    def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
+        """
+        Performs a pseudo-random sampling of the documents from the previous stage.
+
+        This stage filters documents pseudo-randomly.
+        - If an `int` limit is provided, it specifies the maximum number of documents
+          to emit. If fewer documents are available, all are passed through.
+        - If `SampleOptions` are provided, they specify how sampling is performed
+          (e.g., by document count or percentage).
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_stages import SampleOptions
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Sample 10 books, if available.
+            >>> pipeline = pipeline.sample(10)
+            >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10))
+            >>> # Sample 50% of books.
+            >>> pipeline = pipeline.sample(SampleOptions.percentage(0.5))
+
+        Args:
+            limit_or_options: Either an integer specifying the maximum number of
+                documents to sample, or a `SampleOptions` object.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Sample(limit_or_options))
+
+    def union(self, other: "_BasePipeline") -> "_BasePipeline":
+        """
+        Performs a union of all documents from this pipeline and another pipeline,
+        including duplicates.
+
+        This stage passes through documents from the previous stage of this pipeline,
+        and also passes through documents from the previous stage of the `other`
+        pipeline provided. The order of documents emitted from this stage is undefined.
+
+        Example:
+            >>> books_pipeline = client.pipeline().collection("books")
+            >>> magazines_pipeline = client.pipeline().collection("magazines")
+            >>> # Emit documents from both collections
+            >>> combined_pipeline = books_pipeline.union(magazines_pipeline)
+
+        Args:
+            other: The other `Pipeline` whose results will be unioned with this one.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Union(other))
+
+    def unnest(
+        self,
+        field: str | Selectable,
+        alias: str | Field | None = None,
+        options: stages.UnnestOptions | None = None,
+    ) -> "_BasePipeline":
+        """
+        Produces a document for each element in an array field from the previous stage document.
+
+        For each input document, this stage emits zero or more augmented documents:
+        one for each element of the array found in the field named by the `field`
+        parameter. Each emitted document augments the input document by setting the
+        `alias` field to the corresponding array element value. If `alias` is unset,
+        the data in `field` will be overwritten.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Emit a document for each tag
+            >>> pipeline = pipeline.unnest("tags", alias="tag")
+
+            Output documents (without options):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tag": "comedy", ... }
+            { "title": "The Hitchhiker's Guide", "tag": "sci-fi", ... }
+            ```
+
+        Optionally, `UnnestOptions` can specify a field to store the original index
+        of the element within the array.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Emit a document for each tag, including the index
+            >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex"))
+
+            Output documents (with index_field="tagIndex"):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... }
+            { "title": "The Hitchhiker's Guide", "tags": "sci-fi", "tagIndex": 1, ... }
+            ```
+
+        Args:
+            field: The name of the field containing the array to unnest.
+            alias: The field name under which each array element is written in the output document.
+ If unset, or if `alias` matches the `field`, the output data will overwrite the original field. + options: Optional `UnnestOptions` to configure additional behavior, like adding an index field. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Unnest(field, alias, options)) + def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline": """ Adds a generic, named stage to the pipeline with specified parameters. @@ -149,3 +480,132 @@ def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline": A new Pipeline object with this stage appended to the stage list """ return self._append(stages.GenericStage(name, *params)) + + def offset(self, offset: int) -> "_BasePipeline": + """ + Skips the first `offset` number of documents from the results of previous stages. + + This stage is useful for implementing pagination, allowing you to retrieve + results in chunks. It is typically used in conjunction with `limit()` to + control the size of each page. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Retrieve the second page of 20 results (assuming sorted) + >>> pipeline = pipeline.sort(Field.of("published").descending()) + >>> pipeline = pipeline.offset(20) # Skip the first 20 results + >>> pipeline = pipeline.limit(20) # Take the next 20 results + + Args: + offset: The non-negative number of documents to skip. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Offset(offset)) + + def limit(self, limit: int) -> "_BasePipeline": + """ + Limits the maximum number of documents returned by previous stages to `limit`. + + This stage is useful for controlling the size of the result set, often used for: + - **Pagination:** In combination with `offset()` to retrieve specific pages. + - **Top-N queries:** To get a limited number of results after sorting. + - **Performance:** To prevent excessive data transfer. + + Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.pipeline().collection("books") + >>> # Limit the results to the top 10 highest-rated books + >>> pipeline = pipeline.sort(Field.of("rating").descending()) + >>> pipeline = pipeline.limit(10) + + Args: + limit: The non-negative maximum number of documents to return. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Limit(limit)) + + def aggregate( + self, + *accumulators: ExprWithAlias[Accumulator], + groups: Sequence[str | Selectable] = (), + ) -> "_BasePipeline": + """ + Performs aggregation operations on the documents from previous stages, + optionally grouped by specified fields or expressions. + + This stage allows you to calculate aggregate values (like sum, average, count, + min, max) over a set of documents. + + - **Accumulators:** Define the aggregation calculations using `Accumulator` + expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined + with `as_()` to name the result field. + - **Groups:** Optionally specify fields (by name or `Selectable`) to group + the documents by. Aggregations are then performed within each distinct group. + If no groups are provided, the aggregation is performed over the entire input. 
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Count, Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Calculate the average rating and total count for all books
+            >>> pipeline = pipeline.aggregate(
+            ...     Field.of("rating").avg().as_("averageRating"),
+            ...     Field.of("rating").count().as_("totalBooks")
+            ... )
+            >>> # Calculate the average rating for each genre
+            >>> pipeline = pipeline.aggregate(
+            ...     Field.of("rating").avg().as_("avg_rating"),
+            ...     groups=["genre"]  # Group by the 'genre' field
+            ... )
+            >>> # Calculate the count for each author, grouping by Field object
+            >>> pipeline = pipeline.aggregate(
+            ...     Count().as_("bookCount"),
+            ...     groups=[Field.of("author")]
+            ... )
+
+        Args:
+            *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining
+                the aggregations to perform and their output names.
+            groups: An optional sequence of field names (str) or `Selectable`
+                expressions to group by before aggregating.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Aggregate(*accumulators, groups=groups))
+
+    def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
+        """
+        Returns documents with distinct combinations of values for the specified
+        fields or expressions.
+
+        This stage filters the results from previous stages to include only one
+        document for each unique combination of values in the specified `fields`.
+        The output documents contain only the fields specified in the `distinct` call.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Get a list of unique genres (output has only 'genre' field)
+            >>> pipeline = pipeline.distinct("genre")
+            >>> # Get unique combinations of author (uppercase) and genre
+            >>> pipeline = pipeline.distinct(
+            ...     Field.of("author").to_upper().as_("authorUpper"),
+            ...     Field.of("genre")
+            ... )
+
+        Args:
+            *fields: Field names (str) or `Selectable` expressions to consider when
+                determining distinct value combinations. The output will only
+                contain these fields/expressions.
+ + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Distinct(*fields)) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 5e0c775a2..70d619d3b 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -15,17 +15,22 @@ from __future__ import annotations from typing import ( Any, + List, Generic, TypeVar, Dict, + Sequence, ) from abc import ABC from abc import abstractmethod +from enum import Enum import datetime from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint from google.cloud.firestore_v1._helpers import encode_value +from google.cloud.firestore_v1._helpers import decode_value CONSTANT_TYPE = TypeVar( "CONSTANT_TYPE", @@ -43,6 +48,48 @@ ) +class Ordering: + """Represents the direction for sorting results in a pipeline.""" + + class Direction(Enum): + ASCENDING = "ascending" + DESCENDING = "descending" + + def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): + """ + Initializes an Ordering instance + + Args: + expr (Expr | str): The expression or field path string to sort by. + If a string is provided, it's treated as a field path. + order_dir (Direction | str): The direction to sort in. + Defaults to ascending + """ + self.expr = expr if isinstance(expr, Expr) else Field.of(expr) + self.order_dir = ( + Ordering.Direction[order_dir.upper()] + if isinstance(order_dir, str) + else order_dir + ) + + def __repr__(self): + if self.order_dir is Ordering.Direction.ASCENDING: + order_str = ".ascending()" + else: + order_str = ".descending()" + return f"{self.expr!r}{order_str}" + + def _to_pb(self) -> Value: + return Value( + map_value={ + "fields": { + "direction": Value(string_value=self.order_dir.value), + "expression": self.expr._to_pb(), + } + } + ) + + class Expr(ABC): """Represents an expression that can be evaluated to a value within the execution of a pipeline. @@ -66,6 +113,794 @@ def __repr__(self): def _to_pb(self) -> Value: raise NotImplementedError + @staticmethod + def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": + return o if isinstance(o, Expr) else Constant(o) + + def add(self, other: Expr | float) -> "Add": + """Creates an expression that adds this expression to another expression or constant. + + Example: + >>> # Add the value of the 'quantity' field and the 'reserve' field. + >>> Field.of("quantity").add(Field.of("reserve")) + >>> # Add 5 to the value of the 'age' field + >>> Field.of("age").add(5) + + Args: + other: The expression or constant value to add to this expression. + + Returns: + A new `Expr` representing the addition operation. + """ + return Add(self, self._cast_to_expr_or_convert_to_constant(other)) + + def subtract(self, other: Expr | float) -> "Subtract": + """Creates an expression that subtracts another expression or constant from this expression. + + Example: + >>> # Subtract the 'discount' field from the 'price' field + >>> Field.of("price").subtract(Field.of("discount")) + >>> # Subtract 20 from the value of the 'total' field + >>> Field.of("total").subtract(20) + + Args: + other: The expression or constant value to subtract from this expression. + + Returns: + A new `Expr` representing the subtraction operation. 
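+
+        Because `subtract` itself returns an `Expr`, calls can be chained; a
+        small sketch (field names here are illustrative only):
+
+            >>> # net = total - discount - 2
+            >>> Field.of("total").subtract(Field.of("discount")).subtract(2)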
+ """ + return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + + def multiply(self, other: Expr | float) -> "Multiply": + """Creates an expression that multiplies this expression by another expression or constant. + + Example: + >>> # Multiply the 'quantity' field by the 'price' field + >>> Field.of("quantity").multiply(Field.of("price")) + >>> # Multiply the 'value' field by 2 + >>> Field.of("value").multiply(2) + + Args: + other: The expression or constant value to multiply by. + + Returns: + A new `Expr` representing the multiplication operation. + """ + return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + + def divide(self, other: Expr | float) -> "Divide": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> # Divide the 'total' field by the 'count' field + >>> Field.of("total").divide(Field.of("count")) + >>> # Divide the 'value' field by 10 + >>> Field.of("value").divide(10) + + Args: + other: The expression or constant value to divide by. + + Returns: + A new `Expr` representing the division operation. + """ + return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + + def mod(self, other: Expr | float) -> "Mod": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> # Calculate the remainder of dividing the 'value' field by field 'divisor'. + >>> Field.of("value").mod(Field.of("divisor")) + >>> # Calculate the remainder of dividing the 'value' field by 5. + >>> Field.of("value").mod(5) + + Args: + other: The divisor expression or constant. + + Returns: + A new `Expr` representing the modulo operation. + """ + return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the larger value between the 'discount' field and the 'cap' field. + >>> Field.of("discount").logical_max(Field.of("cap")) + >>> # Returns the larger value between the 'value' field and 10. + >>> Field.of("value").logical_max(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical max operation. + """ + return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the smaller value between the 'discount' field and the 'floor' field. + >>> Field.of("discount").logical_min(Field.of("floor")) + >>> # Returns the smaller value between the 'value' field and 10. + >>> Field.of("value").logical_min(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical min operation. 
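+
+        Since both `logical_min` and `logical_max` return an `Expr`, the two can
+        be combined; a small sketch (field name is illustrative only):
+
+            >>> # Clamp 'value' into the range [0, 100]
+            >>> Field.of("value").logical_min(100).logical_max(0)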
+ """ + return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + + def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + """Creates an expression that checks if this expression is equal to another + expression or constant value. + + Example: + >>> # Check if the 'age' field is equal to 21 + >>> Field.of("age").eq(21) + >>> # Check if the 'city' field is equal to "London" + >>> Field.of("city").eq("London") + + Args: + other: The expression or constant value to compare for equality. + + Returns: + A new `Expr` representing the equality comparison. + """ + return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> # Check if the 'status' field is not equal to "completed" + >>> Field.of("status").neq("completed") + >>> # Check if the 'country' field is not equal to "USA" + >>> Field.of("country").neq("USA") + + Args: + other: The expression or constant value to compare for inequality. + + Returns: + A new `Expr` representing the inequality comparison. + """ + return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is greater than the 'limit' field + >>> Field.of("age").gt(Field.of("limit")) + >>> # Check if the 'price' field is greater than 100 + >>> Field.of("price").gt(100) + + Args: + other: The expression or constant value to compare for greater than. + + Returns: + A new `Expr` representing the greater than comparison. + """ + return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 + >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> # Check if the 'score' field is greater than or equal to 80 + >>> Field.of("score").gte(80) + + Args: + other: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expr` representing the greater than or equal to comparison. + """ + return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + """Creates an expression that checks if this expression is less than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is less than 'limit' + >>> Field.of("age").lt(Field.of('limit')) + >>> # Check if the 'price' field is less than 50 + >>> Field.of("price").lt(50) + + Args: + other: The expression or constant value to compare for less than. + + Returns: + A new `Expr` representing the less than comparison. + """ + return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. 
+
+        Example:
+            >>> # Check if the 'quantity' field is less than or equal to 20
+            >>> Field.of("quantity").lte(Constant.of(20))
+            >>> # Check if the 'score' field is less than or equal to 70
+            >>> Field.of("score").lte(70)
+
+        Args:
+            other: The expression or constant value to compare for less than or equal to.
+
+        Returns:
+            A new `Expr` representing the less than or equal to comparison.
+        """
+        return Lte(self, self._cast_to_expr_or_convert_to_constant(other))
+
+    def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In":
+        """Creates an expression that checks if this expression is equal to any of the
+        provided values or expressions.
+
+        Example:
+            >>> # Check if the 'category' field is either "Electronics" or the value of field 'primaryType'
+            >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")])
+
+        Args:
+            array: The values or expressions to check against.
+
+        Returns:
+            A new `Expr` representing the 'IN' comparison.
+        """
+        return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array])
+
+    def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not":
+        """Creates an expression that checks if this expression is not equal to any of the
+        provided values or expressions.
+
+        Example:
+            >>> # Check if the 'status' field is neither "pending" nor "cancelled"
+            >>> Field.of("status").not_in_any(["pending", "cancelled"])
+
+        Args:
+            array: The values or expressions to check against.
+
+        Returns:
+            A new `Expr` representing the 'NOT IN' comparison.
+        """
+        return Not(self.in_any(array))
+
+    def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains":
+        """Creates an expression that checks if an array contains a specific element or value.
+
+        Example:
+            >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field
+            >>> Field.of("sizes").array_contains(Field.of("selectedSize"))
+            >>> # Check if the 'colors' array contains "red"
+            >>> Field.of("colors").array_contains("red")
+
+        Args:
+            element: The element (expression or constant) to search for in the array.
+
+        Returns:
+            A new `Expr` representing the 'array_contains' comparison.
+        """
+        return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element))
+
+    def array_contains_all(
+        self, elements: List[Expr | CONSTANT_TYPE]
+    ) -> "ArrayContainsAll":
+        """Creates an expression that checks if an array contains all the specified elements.
+
+        Example:
+            >>> # Check if the 'tags' array contains both "news" and "sports"
+            >>> Field.of("tags").array_contains_all(["news", "sports"])
+            >>> # Check if the 'tags' array contains both the value of field 'tag1' and "tag2"
+            >>> Field.of("tags").array_contains_all([Field.of("tag1"), "tag2"])
+
+        Args:
+            elements: The list of elements (expressions or constants) to check for in the array.
+
+        Returns:
+            A new `Expr` representing the 'array_contains_all' comparison.
+        """
+        return ArrayContainsAll(
+            self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements]
+        )
+
+    def array_contains_any(
+        self, elements: List[Expr | CONSTANT_TYPE]
+    ) -> "ArrayContainsAny":
+        """Creates an expression that checks if an array contains any of the specified elements.
+ + Example: + >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" + >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) + >>> # Check if the 'groups' array contains either the value from the 'userGroup' field + >>> # or the value "guest" + >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_any' comparison. + """ + return ArrayContainsAny( + self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ) + + def array_length(self) -> "ArrayLength": + """Creates an expression that calculates the length of an array. + + Example: + >>> # Get the number of items in the 'cart' array + >>> Field.of("cart").array_length() + + Returns: + A new `Expr` representing the length of the array. + """ + return ArrayLength(self) + + def array_reverse(self) -> "ArrayReverse": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> # Get the 'preferences' array in reversed order. + >>> Field.of("preferences").array_reverse() + + Returns: + A new `Expr` representing the reversed array. + """ + return ArrayReverse(self) + + def is_nan(self) -> "IsNaN": + """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). + + Example: + >>> # Check if the result of a calculation is NaN + >>> Field.of("value").divide(0).is_nan() + + Returns: + A new `Expr` representing the 'isNaN' check. + """ + return IsNaN(self) + + def exists(self) -> "Exists": + """Creates an expression that checks if a field exists in the document. + + Example: + >>> # Check if the document has a field named "phoneNumber" + >>> Field.of("phoneNumber").exists() + + Returns: + A new `Expr` representing the 'exists' check. + """ + return Exists(self) + + def sum(self) -> "Sum": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> # Calculate the total revenue from a set of orders + >>> Field.of("orderAmount").sum().as_("totalRevenue") + + Returns: + A new `Accumulator` representing the 'sum' aggregation. + """ + return Sum(self) + + def avg(self) -> "Avg": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> # Calculate the average age of users + >>> Field.of("age").avg().as_("averageAge") + + Returns: + A new `Accumulator` representing the 'avg' aggregation. + """ + return Avg(self) + + def count(self) -> "Count": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. + + Example: + >>> # Count the total number of products + >>> Field.of("productId").count().as_("totalProducts") + + Returns: + A new `Accumulator` representing the 'count' aggregation. + """ + return Count(self) + + def min(self) -> "Min": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> # Find the lowest price of all products + >>> Field.of("price").min().as_("lowestPrice") + + Returns: + A new `Accumulator` representing the 'min' aggregation. + """ + return Min(self) + + def max(self) -> "Max": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. 
+ + Example: + >>> # Find the highest score in a leaderboard + >>> Field.of("score").max().as_("highestScore") + + Returns: + A new `Accumulator` representing the 'max' aggregation. + """ + return Max(self) + + def char_length(self) -> "CharLength": + """Creates an expression that calculates the character length of a string. + + Example: + >>> # Get the character length of the 'name' field + >>> Field.of("name").char_length() + + Returns: + A new `Expr` representing the length of the string. + """ + return CharLength(self) + + def byte_length(self) -> "ByteLength": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. + + Example: + >>> # Get the byte length of the 'name' field + >>> Field.of("name").byte_length() + + Returns: + A new `Expr` representing the byte length of the string. + """ + return ByteLength(self) + + def like(self, pattern: Expr | str) -> "Like": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> # Check if the 'title' field contains the word "guide" (case-sensitive) + >>> Field.of("title").like("%guide%") + >>> # Check if the 'title' field matches the pattern specified in field 'pattern'. + >>> Field.of("title").like(Field.of("pattern")) + + Args: + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expr` representing the 'like' comparison. + """ + return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + + def regex_contains(self, regex: Expr | str) -> "RegexContains": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> # Check if the 'description' field contains "example" (case-insensitive) + >>> Field.of("description").regex_contains("(?i)example") + >>> # Check if the 'description' field contains the regular expression stored in field 'regex' + >>> Field.of("description").regex_contains(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def regex_matches(self, regex: Expr | str) -> "RegexMatch": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> # Check if the 'email' field matches a regular expression stored in field 'regex' + >>> Field.of("email").regex_matches(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expr` representing the regular expression match. + """ + return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def str_contains(self, substring: Expr | str) -> "StrContains": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> # Check if the 'description' field contains "example". + >>> Field.of("description").str_contains("example") + >>> # Check if the 'description' field contains the value of the 'keyword' field. + >>> Field.of("description").str_contains(Field.of("keyword")) + + Args: + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. 
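+
+        A sketch of use inside a pipeline filter, assuming string predicates
+        such as this one are accepted by the `where` stage as filter conditions
+        (collection and field names are illustrative only):
+
+            >>> # Keep only documents whose 'description' mentions "example"
+            >>> firestore.pipeline().collection("books").where(
+            ...     Field.of("description").str_contains("example")
+            ... )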
+ """ + return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + + def starts_with(self, prefix: Expr | str) -> "StartsWith": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> # Check if the 'name' field starts with "Mr." + >>> Field.of("name").starts_with("Mr.") + >>> # Check if the 'fullName' field starts with the value of the 'firstName' field + >>> Field.of("fullName").starts_with(Field.of("firstName")) + + Args: + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'starts with' comparison. + """ + return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + + def ends_with(self, postfix: Expr | str) -> "EndsWith": + """Creates an expression that checks if a string ends with a given postfix. + + Example: + >>> # Check if the 'filename' field ends with ".txt" + >>> Field.of("filename").ends_with(".txt") + >>> # Check if the 'url' field ends with the value of the 'extension' field + >>> Field.of("url").ends_with(Field.of("extension")) + + Args: + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'ends with' comparison. + """ + return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + + def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string + >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + + Args: + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expr` representing the concatenated string. + """ + return StrConcat( + self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] + ) + + def map_get(self, key: str) -> "MapGet": + """Accesses a value from a map (object) field using the provided key. + + Example: + >>> # Get the 'city' value from + >>> # the 'address' map field + >>> Field.of("address").map_get("city") + + Args: + key: The key to access in the map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + return MapGet(self, Constant.of(key)) + + def vector_length(self) -> "VectorLength": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> # Get the vector length (dimension) of the field 'embedding'. + >>> Field.of("embedding").vector_length() + + Returns: + A new `Expr` representing the length of the vector. + """ + return VectorLength(self) + + def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> # Convert the 'timestamp' field to microseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_micros() + + Returns: + A new `Expr` representing the number of microseconds since the epoch. + """ + return TimestampToUnixMicros(self) + + def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'microseconds' field to a timestamp. 
+ >>> Field.of("microseconds").unix_micros_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixMicrosToTimestamp(self) + + def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> # Convert the 'timestamp' field to milliseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_millis() + + Returns: + A new `Expr` representing the number of milliseconds since the epoch. + """ + return TimestampToUnixMillis(self) + + def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'milliseconds' field to a timestamp. + >>> Field.of("milliseconds").unix_millis_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixMillisToTimestamp(self) + + def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> # Convert the 'timestamp' field to seconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_seconds() + + Returns: + A new `Expr` representing the number of seconds since the epoch. + """ + return TimestampToUnixSeconds(self) + + def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> # Convert the 'seconds' field to a timestamp. + >>> Field.of("seconds").unix_seconds_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixSecondsToTimestamp(self) + + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> # Add a duration specified by the 'unit' and 'amount' fields to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add(Field.of("unit"), Field.of("amount")) + >>> # Add 1.5 days to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add("day", 1.5) + + Args: + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + return TimestampAdd( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) + + def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. + >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> # Subtract 2.5 hours from the 'timestamp' field. 
+ >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + + Args: + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + return TimestampSub( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) + + def ascending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in ascending order based on this expression. + + Example: + >>> # Sort documents by the 'name' field in ascending order + >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + + Returns: + A new `Ordering` for ascending sorting. + """ + return Ordering(self, Ordering.Direction.ASCENDING) + + def descending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in descending order based on this expression. + + Example: + >>> # Sort documents by the 'createdAt' field in descending order + >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + + Returns: + A new `Ordering` for descending sorting. + """ + return Ordering(self, Ordering.Direction.DESCENDING) + + def as_(self, alias: str) -> "ExprWithAlias": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. + + Example: + >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. + >>> firestore.pipeline().collection("items").add_fields( + ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") + ... ) + + Args: + alias: The alias to assign to this expression. + + Returns: + A new `Selectable` (typically an `ExprWithAlias`) that wraps this + expression and associates it with the provided alias. 
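+
+        A further sketch combining `as_` with an aggregation (`avg` and `as_`
+        are both methods defined on `Expr` in this module):
+
+            >>> # Name the average of 'price' for use in an aggregate stage
+            >>> Field.of("price").avg().as_("avgPrice")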
+ """ + return ExprWithAlias(self, alias) + class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" @@ -73,6 +908,12 @@ class Constant(Expr, Generic[CONSTANT_TYPE]): def __init__(self, value: CONSTANT_TYPE): self.value: CONSTANT_TYPE = value + def __eq__(self, other): + if not isinstance(other, Constant): + return other == self.value + else: + return other.value == self.value + @staticmethod def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: """Creates a constant expression from a Python value.""" @@ -83,3 +924,1379 @@ def __repr__(self): def _to_pb(self) -> Value: return encode_value(self.value) + + +class ListOfExprs(Expr): + """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" + + def __init__(self, exprs: List[Expr]): + self.exprs: list[Expr] = exprs + + def __eq__(self, other): + if not isinstance(other, ListOfExprs): + return False + else: + return other.exprs == self.exprs + + def __repr__(self): + return f"{self.__class__.__name__}({self.exprs})" + + def _to_pb(self): + return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) + + +class Function(Expr): + """A base class for expressions that represent function calls.""" + + def __init__(self, name: str, params: Sequence[Expr]): + self.name = name + self.params = list(params) + + def __eq__(self, other): + if not isinstance(other, Function): + return False + else: + return other.name == self.name and other.params == self.params + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" + + def _to_pb(self): + return Value( + function_value={ + "name": self.name, + "args": [p._to_pb() for p in self.params], + } + ) + + def add(left: Expr | str, right: Expr | float) -> "Add": + """Creates an expression that adds two expressions together. + + Example: + >>> Function.add("rating", 5) + >>> Function.add(Field.of("quantity"), Field.of("reserve")) + + Args: + left: The first expression or field path to add. + right: The second expression or constant value to add. + + Returns: + A new `Expr` representing the addition operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.add(left_expr, right) + + def subtract(left: Expr | str, right: Expr | float) -> "Subtract": + """Creates an expression that subtracts another expression or constant from this expression. + + Example: + >>> Function.subtract("total", 20) + >>> Function.subtract(Field.of("price"), Field.of("discount")) + + Args: + left: The expression or field path to subtract from. + right: The expression or constant value to subtract. + + Returns: + A new `Expr` representing the subtraction operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.subtract(left_expr, right) + + def multiply(left: Expr | str, right: Expr | float) -> "Multiply": + """Creates an expression that multiplies this expression by another expression or constant. + + Example: + >>> Function.multiply("value", 2) + >>> Function.multiply(Field.of("quantity"), Field.of("price")) + + Args: + left: The expression or field path to multiply. + right: The expression or constant value to multiply by. + + Returns: + A new `Expr` representing the multiplication operation. 
+ """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.multiply(left_expr, right) + + def divide(left: Expr | str, right: Expr | float) -> "Divide": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> Function.divide("value", 10) + >>> Function.divide(Field.of("total"), Field.of("count")) + + Args: + left: The expression or field path to be divided. + right: The expression or constant value to divide by. + + Returns: + A new `Expr` representing the division operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.divide(left_expr, right) + + def mod(left: Expr | str, right: Expr | float) -> "Mod": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> Function.mod("value", 5) + >>> Function.mod(Field.of("value"), Field.of("divisor")) + + Args: + left: The dividend expression or field path. + right: The divisor expression or constant. + + Returns: + A new `Expr` representing the modulo operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.mod(left_expr, right) + + def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> Function.logical_max("value", 10) + >>> Function.logical_max(Field.of("discount"), Field.of("cap")) + + Args: + left: The expression or field path to compare. + right: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical max operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.logical_max(left_expr, right) + + def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> Function.logical_min("value", 10) + >>> Function.logical_min(Field.of("discount"), Field.of("floor")) + + Args: + left: The expression or field path to compare. + right: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical min operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.logical_min(left_expr, right) + + def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": + """Creates an expression that checks if this expression is equal to another + expression or constant value. + + Example: + >>> Function.eq("city", "London") + >>> Function.eq(Field.of("age"), 21) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for equality. + + Returns: + A new `Expr` representing the equality comparison. 
+ """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.eq(left_expr, right) + + def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> Function.neq("country", "USA") + >>> Function.neq(Field.of("status"), "completed") + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for inequality. + + Returns: + A new `Expr` representing the inequality comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.neq(left_expr, right) + + def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> Function.gt("price", 100) + >>> Function.gt(Field.of("age"), Field.of("limit")) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for greater than. + + Returns: + A new `Expr` representing the greater than comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.gt(left_expr, right) + + def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> Function.gte("score", 80) + >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expr` representing the greater than or equal to comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.gte(left_expr, right) + + def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": + """Creates an expression that checks if this expression is less than another + expression or constant value. + + Example: + >>> Function.lt("price", 50) + >>> Function.lt(Field.of("age"), Field.of('limit')) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for less than. + + Returns: + A new `Expr` representing the less than comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.lt(left_expr, right) + + def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. + + Example: + >>> Function.lte("score", 70) + >>> Function.lte(Field.of("quantity"), Constant.of(20)) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for less than or equal to. + + Returns: + A new `Expr` representing the less than or equal to comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.lte(left_expr, right) + + def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": + """Creates an expression that checks if this expression is equal to any of the + provided values or expressions. + + Example: + >>> Function.in_any("category", ["Electronics", "Apparel"]) + >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) + + Args: + left: The expression or field path to compare. 
+ array: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'IN' comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.in_any(left_expr, array) + + def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": + """Creates an expression that checks if this expression is not equal to any of the + provided values or expressions. + + Example: + >>> Function.not_in_any("status", ["pending", "cancelled"]) + + Args: + left: The expression or field path to compare. + array: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'NOT IN' comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.not_in_any(left_expr, array) + + def array_contains( + array: Expr | str, element: Expr | CONSTANT_TYPE + ) -> "ArrayContains": + """Creates an expression that checks if an array contains a specific element or value. + + Example: + >>> Function.array_contains("colors", "red") + >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) + + Args: + array: The array expression or field path to check. + element: The element (expression or constant) to search for in the array. + + Returns: + A new `Expr` representing the 'array_contains' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains(array_expr, element) + + def array_contains_all( + array: Expr | str, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAll": + """Creates an expression that checks if an array contains all the specified elements. + + Example: + >>> Function.array_contains_all("tags", ["news", "sports"]) + >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) + + Args: + array: The array expression or field path to check. + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_all' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains_all(array_expr, elements) + + def array_contains_any( + array: Expr | str, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAny": + """Creates an expression that checks if an array contains any of the specified elements. + + Example: + >>> Function.array_contains_any("groups", ["admin", "editor"]) + >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) + + Args: + array: The array expression or field path to check. + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_any' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains_any(array_expr, elements) + + def array_length(array: Expr | str) -> "ArrayLength": + """Creates an expression that calculates the length of an array. + + Example: + >>> Function.array_length("cart") + + Returns: + A new `Expr` representing the length of the array. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_length(array_expr) + + def array_reverse(array: Expr | str) -> "ArrayReverse": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> Function.array_reverse("preferences") + + Returns: + A new `Expr` representing the reversed array. 
+ """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_reverse(array_expr) + + def is_nan(expr: Expr | str) -> "IsNaN": + """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). + + Example: + >>> Function.is_nan("measurement") + + Returns: + A new `Expr` representing the 'isNaN' check. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.is_nan(expr_val) + + def exists(expr: Expr | str) -> "Exists": + """Creates an expression that checks if a field exists in the document. + + Example: + >>> Function.exists("phoneNumber") + + Returns: + A new `Expr` representing the 'exists' check. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.exists(expr_val) + + def sum(expr: Expr | str) -> "Sum": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> Function.sum("orderAmount") + + Returns: + A new `Accumulator` representing the 'sum' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.sum(expr_val) + + def avg(expr: Expr | str) -> "Avg": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> Function.avg("age") + + Returns: + A new `Accumulator` representing the 'avg' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.avg(expr_val) + + def count(expr: Expr | str | None = None) -> "Count": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. If no expression is provided, it counts all inputs. + + Example: + >>> Function.count("productId") + >>> Function.count() + + Returns: + A new `Accumulator` representing the 'count' aggregation. + """ + if expr is None: + return Count() + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.count(expr_val) + + def min(expr: Expr | str) -> "Min": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> Function.min("price") + + Returns: + A new `Accumulator` representing the 'min' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.min(expr_val) + + def max(expr: Expr | str) -> "Max": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. + + Example: + >>> Function.max("score") + + Returns: + A new `Accumulator` representing the 'max' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.max(expr_val) + + def char_length(expr: Expr | str) -> "CharLength": + """Creates an expression that calculates the character length of a string. + + Example: + >>> Function.char_length("name") + + Returns: + A new `Expr` representing the length of the string. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.char_length(expr_val) + + def byte_length(expr: Expr | str) -> "ByteLength": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. + + Example: + >>> Function.byte_length("name") + + Returns: + A new `Expr` representing the byte length of the string. 
+ """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.byte_length(expr_val) + + def like(expr: Expr | str, pattern: Expr | str) -> "Like": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> Function.like("title", "%guide%") + >>> Function.like(Field.of("title"), Field.of("pattern")) + + Args: + expr: The expression or field path to perform the comparison on. + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expr` representing the 'like' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.like(expr_val, pattern) + + def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> Function.regex_contains("description", "(?i)example") + >>> Function.regex_contains(Field.of("description"), Field.of("regex")) + + Args: + expr: The expression or field path to perform the comparison on. + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.regex_contains(expr_val, regex) + + def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> Function.regex_matches(Field.of("email"), Field.of("regex")) + + Args: + expr: The expression or field path to match against. + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expr` representing the regular expression match. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.regex_matches(expr_val, regex) + + def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> Function.str_contains("description", "example") + >>> Function.str_contains(Field.of("description"), Field.of("keyword")) + + Args: + expr: The expression or field path to perform the comparison on. + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.str_contains(expr_val, substring) + + def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> Function.starts_with("name", "Mr.") + >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) + + Args: + expr: The expression or field path to check. + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'starts with' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.starts_with(expr_val, prefix) + + def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": + """Creates an expression that checks if a string ends with a given postfix. 
+ + Example: + >>> Function.ends_with("filename", ".txt") + >>> Function.ends_with(Field.of("url"), Field.of("extension")) + + Args: + expr: The expression or field path to check. + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'ends with' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.ends_with(expr_val, postfix) + + def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> Function.str_concat("firstName", " ", Field.of("lastName")) + + Args: + first: The first expression or field path to concatenate. + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expr` representing the concatenated string. + """ + first_expr = Field.of(first) if isinstance(first, str) else first + return Expr.str_concat(first_expr, *elements) + + def map_get(map_expr: Expr | str, key: str) -> "MapGet": + """Accesses a value from a map (object) field using the provided key. + + Example: + >>> Function.map_get("address", "city") + + Args: + map_expr: The expression or field path of the map. + key: The key to access in the map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr + return Expr.map_get(map_val, key) + + def vector_length(vector_expr: Expr | str) -> "VectorLength": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> Function.vector_length("embedding") + + Returns: + A new `Expr` representing the length of the vector. + """ + vector_val = ( + Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr + ) + return Expr.vector_length(vector_val) + + def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> Function.timestamp_to_unix_micros("timestamp") + + Returns: + A new `Expr` representing the number of microseconds since the epoch. + """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_micros(timestamp_val) + + def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> Function.unix_micros_to_timestamp("microseconds") + + Returns: + A new `Expr` representing the timestamp. + """ + micros_val = ( + Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr + ) + return Expr.unix_micros_to_timestamp(micros_val) + + def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> Function.timestamp_to_unix_millis("timestamp") + + Returns: + A new `Expr` representing the number of milliseconds since the epoch. 
+ """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_millis(timestamp_val) + + def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> Function.unix_millis_to_timestamp("milliseconds") + + Returns: + A new `Expr` representing the timestamp. + """ + millis_val = ( + Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr + ) + return Expr.unix_millis_to_timestamp(millis_val) + + def timestamp_to_unix_seconds( + timestamp_expr: Expr | str, + ) -> "TimestampToUnixSeconds": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> Function.timestamp_to_unix_seconds("timestamp") + + Returns: + A new `Expr` representing the number of seconds since the epoch. + """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_seconds(timestamp_val) + + def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> Function.unix_seconds_to_timestamp("seconds") + + Returns: + A new `Expr` representing the timestamp. + """ + seconds_val = ( + Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr + ) + return Expr.unix_seconds_to_timestamp(seconds_val) + + def timestamp_add( + timestamp: Expr | str, unit: Expr | str, amount: Expr | float + ) -> "TimestampAdd": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> Function.timestamp_add("timestamp", "day", 1.5) + >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) + + Args: + timestamp: The expression or field path of the timestamp. + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + timestamp_expr = ( + Field.of(timestamp) if isinstance(timestamp, str) else timestamp + ) + return Expr.timestamp_add(timestamp_expr, unit, amount) + + def timestamp_sub( + timestamp: Expr | str, unit: Expr | str, amount: Expr | float + ) -> "TimestampSub": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> Function.timestamp_sub("timestamp", "hour", 2.5) + >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) + + Args: + timestamp: The expression or field path of the timestamp. + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expr` representing the resulting timestamp. 
+ """ + timestamp_expr = ( + Field.of(timestamp) if isinstance(timestamp, str) else timestamp + ) + return Expr.timestamp_sub(timestamp_expr, unit, amount) + + +class Divide(Function): + """Represents the division function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("divide", [left, right]) + + +class LogicalMax(Function): + """Represents the logical maximum function based on Firestore type ordering.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_maximum", [left, right]) + + +class LogicalMin(Function): + """Represents the logical minimum function based on Firestore type ordering.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_minimum", [left, right]) + + +class MapGet(Function): + """Represents accessing a value within a map by key.""" + + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_get", [map_, key]) + + +class Mod(Function): + """Represents the modulo function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("mod", [left, right]) + + +class Multiply(Function): + """Represents the multiplication function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("multiply", [left, right]) + + +class Parent(Function): + """Represents getting the parent document reference.""" + + def __init__(self, value: Expr): + super().__init__("parent", [value]) + + +class StrConcat(Function): + """Represents concatenating multiple strings.""" + + def __init__(self, *exprs: Expr): + super().__init__("str_concat", exprs) + + +class Subtract(Function): + """Represents the subtraction function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("subtract", [left, right]) + + +class TimestampAdd(Function): + """Represents adding a duration to a timestamp.""" + + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_add", [timestamp, unit, amount]) + + +class TimestampSub(Function): + """Represents subtracting a duration from a timestamp.""" + + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_sub", [timestamp, unit, amount]) + + +class TimestampToUnixMicros(Function): + """Represents converting a timestamp to microseconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_micros", [input]) + + +class TimestampToUnixMillis(Function): + """Represents converting a timestamp to milliseconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_millis", [input]) + + +class TimestampToUnixSeconds(Function): + """Represents converting a timestamp to seconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_seconds", [input]) + + +class UnixMicrosToTimestamp(Function): + """Represents converting microseconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_micros_to_timestamp", [input]) + + +class UnixMillisToTimestamp(Function): + """Represents converting milliseconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_millis_to_timestamp", [input]) + + +class UnixSecondsToTimestamp(Function): + """Represents converting seconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_seconds_to_timestamp", [input]) + + +class VectorLength(Function): + """Represents getting the length (dimension) of a vector.""" + + def 
__init__(self, array: Expr): + super().__init__("vector_length", [array]) + + +class Add(Function): + """Represents the addition function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("add", [left, right]) + + +class ArrayElement(Function): + """Represents accessing an element within an array""" + + def __init__(self): + super().__init__("array_element", []) + + +class ArrayFilter(Function): + """Represents filtering elements from an array based on a condition.""" + + def __init__(self, array: Expr, filter: "FilterCondition"): + super().__init__("array_filter", [array, filter]) + + +class ArrayLength(Function): + """Represents getting the length of an array.""" + + def __init__(self, array: Expr): + super().__init__("array_length", [array]) + + +class ArrayReverse(Function): + """Represents reversing the elements of an array.""" + + def __init__(self, array: Expr): + super().__init__("array_reverse", [array]) + + +class ArrayTransform(Function): + """Represents applying a transformation function to each element of an array.""" + + def __init__(self, array: Expr, transform: Function): + super().__init__("array_transform", [array, transform]) + + +class ByteLength(Function): + """Represents getting the byte length of a string (UTF-8).""" + + def __init__(self, expr: Expr): + super().__init__("byte_length", [expr]) + + +class CharLength(Function): + """Represents getting the character length of a string.""" + + def __init__(self, expr: Expr): + super().__init__("char_length", [expr]) + + +class CollectionId(Function): + """Represents getting the collection ID from a document reference.""" + + def __init__(self, value: Expr): + super().__init__("collection_id", [value]) + + +class Accumulator(Function): + """A base class for aggregation functions that operate across multiple inputs.""" + + +class Max(Accumulator): + """Represents the maximum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("maximum", [value]) + + +class Min(Accumulator): + """Represents the minimum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("minimum", [value]) + + +class Sum(Accumulator): + """Represents the sum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("sum", [value]) + + +class Avg(Accumulator): + """Represents the average aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("avg", [value]) + + +class Count(Accumulator): + """Represents an aggregation that counts the total number of inputs.""" + + def __init__(self, value: Expr | None = None): + super().__init__("count", [value] if value else []) + + +class Selectable(Expr): + """Base class for expressions that can be selected or aliased in projection stages.""" + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return other._to_map() == self._to_map() + + @abstractmethod + def _to_map(self) -> tuple[str, Value]: + """ + Returns a str: Value representation of the Selectable + """ + raise NotImplementedError + + @classmethod + def _value_from_selectables(cls, *selectables: Selectable) -> Value: + """ + Returns a Value representing a map of Selectables + """ + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} + } + ) + + +T = TypeVar("T", bound=Expr) + + +class ExprWithAlias(Selectable, Generic[T]): + """Wraps an expression with an alias.""" + + def __init__(self, expr: T, alias: str): + self.expr = expr + self.alias = alias + + 
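
+    # NOTE: as a Selectable, ExprWithAlias is consumed via _to_map, which
+    # yields an (alias, value) entry for projection stages; _to_pb (below)
+    # wraps the same entry in a standalone map Value.
+    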
def _to_map(self): + return self.alias, self.expr._to_pb() + + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" + + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) + + +class Field(Selectable): + """Represents a reference to a field within a document.""" + + DOCUMENT_ID = "__name__" + + def __init__(self, path: str): + """Initializes a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + """ + self.path = path + + @staticmethod + def of(path: str): + """Creates a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + + Returns: + A new Field instance. + """ + return Field(path) + + def _to_map(self): + return self.path, self._to_pb() + + def __repr__(self): + return f"Field.of({self.path!r})" + + def _to_pb(self): + return Value(field_reference_value=self.path) + + +class FilterCondition(Function): + """Filters the given data in some way.""" + + def __init__( + self, + *args, + use_infix_repr: bool = True, + infix_name_override: str | None = None, + **kwargs, + ): + self._use_infix_repr = use_infix_repr + self._infix_name_override = infix_name_override + super().__init__(*args, **kwargs) + + def __repr__(self): + """ + Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). + + Display them this way in the repr string where possible + """ + if self._use_infix_repr: + infix_name = self._infix_name_override or self.name + if len(self.params) == 1: + return f"{self.params[0]!r}.{infix_name}()" + elif len(self.params) == 2: + return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" + return super().__repr__() + + @staticmethod + def _from_query_filter_pb(filter_pb, client): + if isinstance(filter_pb, Query_pb.CompositeFilter): + sub_filters = [ + FilterCondition._from_query_filter_pb(f, client) + for f in filter_pb.filters + ] + if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: + return Or(*sub_filters) + elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND: + return And(*sub_filters) + else: + raise TypeError( + f"Unexpected CompositeFilter operator type: {filter_pb.op}" + ) + elif isinstance(filter_pb, Query_pb.UnaryFilter): + field = Field.of(filter_pb.field.field_path) + if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: + return And(field.exists(), field.is_nan()) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: + return And(field.exists(), Not(field.is_nan())) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: + return And(field.exists(), field.eq(None)) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: + return And(field.exists(), Not(field.eq(None))) + else: + raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") + elif isinstance(filter_pb, Query_pb.FieldFilter): + field = Field.of(filter_pb.field.field_path) + value = decode_value(filter_pb.value, client) + if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: + return And(field.exists(), field.lt(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: + return And(field.exists(), field.lte(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: + return And(field.exists(), field.gt(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: + return And(field.exists(), field.gte(value)) + elif filter_pb.op == 
Query_pb.FieldFilter.Operator.EQUAL:
+                return And(field.exists(), field.eq(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL:
+                return And(field.exists(), field.neq(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS:
+                return And(field.exists(), field.array_contains(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY:
+                return And(field.exists(), field.array_contains_any(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.IN:
+                return And(field.exists(), field.in_any(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN:
+                return And(field.exists(), field.not_in_any(value))
+            else:
+                raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}")
+        elif isinstance(filter_pb, Query_pb.Filter):
+            # unwrap oneof
+            f = (
+                filter_pb.composite_filter
+                or filter_pb.field_filter
+                or filter_pb.unary_filter
+            )
+            return FilterCondition._from_query_filter_pb(f, client)
+        else:
+            raise TypeError(f"Unexpected filter type: {type(filter_pb)}")
+
+
+class And(FilterCondition):
+    """Represents the logical AND of multiple filter conditions."""
+
+    def __init__(self, *conditions: "FilterCondition"):
+        super().__init__("and", conditions, use_infix_repr=False)
+
+
+class ArrayContains(FilterCondition):
+    """Represents checking if an array contains a specific element."""
+
+    def __init__(self, array: Expr, element: Expr):
+        super().__init__(
+            "array_contains", [array, element if element else Constant(None)]
+        )
+
+
+class ArrayContainsAll(FilterCondition):
+    """Represents checking if an array contains all specified elements."""
+
+    def __init__(self, array: Expr, elements: List[Expr]):
+        super().__init__("array_contains_all", [array, ListOfExprs(elements)])
+
+
+class ArrayContainsAny(FilterCondition):
+    """Represents checking if an array contains any of the specified elements."""
+
+    def __init__(self, array: Expr, elements: List[Expr]):
+        super().__init__("array_contains_any", [array, ListOfExprs(elements)])
+
+
+class EndsWith(FilterCondition):
+    """Represents checking if a string ends with a specific postfix."""
+
+    def __init__(self, expr: Expr, postfix: Expr):
+        super().__init__("ends_with", [expr, postfix])
+
+
+class Eq(FilterCondition):
+    """Represents the equality comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("eq", [left, right if right else Constant(None)])
+
+
+class Exists(FilterCondition):
+    """Represents checking if a field exists."""
+
+    def __init__(self, expr: Expr):
+        super().__init__("exists", [expr])
+
+
+class Gt(FilterCondition):
+    """Represents the greater than comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("gt", [left, right if right else Constant(None)])
+
+
+class Gte(FilterCondition):
+    """Represents the greater than or equal to comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("gte", [left, right if right else Constant(None)])
+
+
+class If(FilterCondition):
+    """Represents a conditional expression (if-then-else)."""
+
+    def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr):
+        super().__init__(
+            "if", [condition, true_expr, false_expr if false_expr else Constant(None)]
+        )
+
+
+class In(FilterCondition):
+    """Represents checking if an expression's value is within a list of values."""
+
+    def __init__(self, left: Expr, others: List[Expr]):
+        super().__init__(
+            "in", [left, ListOfExprs(others)], infix_name_override="in_any"
+        )
+
+
+class IsNaN(FilterCondition):
+    """Represents checking if a numeric value is NaN."""
+
+    def __init__(self, value: Expr):
+        
super().__init__("is_nan", [value]) + + +class Like(FilterCondition): + """Represents a case-sensitive wildcard string comparison.""" + + def __init__(self, expr: Expr, pattern: Expr): + super().__init__("like", [expr, pattern]) + + +class Lt(FilterCondition): + """Represents the less than comparison.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("lt", [left, right if right else Constant(None)]) + + +class Lte(FilterCondition): + """Represents the less than or equal to comparison.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("lte", [left, right if right else Constant(None)]) + + +class Neq(FilterCondition): + """Represents the inequality comparison.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("neq", [left, right if right else Constant(None)]) + + +class Not(FilterCondition): + """Represents the logical NOT of a filter condition.""" + + def __init__(self, condition: Expr): + super().__init__("not", [condition], use_infix_repr=False) + + +class Or(FilterCondition): + """Represents the logical OR of multiple filter conditions.""" + + def __init__(self, *conditions: "FilterCondition"): + super().__init__("or", conditions) + + +class RegexContains(FilterCondition): + """Represents checking if a string contains a substring matching a regex.""" + + def __init__(self, expr: Expr, regex: Expr): + super().__init__("regex_contains", [expr, regex]) + + +class RegexMatch(FilterCondition): + """Represents checking if a string fully matches a regex.""" + + def __init__(self, expr: Expr, regex: Expr): + super().__init__("regex_match", [expr, regex]) + + +class StartsWith(FilterCondition): + """Represents checking if a string starts with a specific prefix.""" + + def __init__(self, expr: Expr, prefix: Expr): + super().__init__("starts_with", [expr, prefix]) + + +class StrContains(FilterCondition): + """Represents checking if a string contains a specific substring.""" + + def __init__(self, expr: Expr, substring: Expr): + super().__init__("str_contains", [expr, substring]) + + +class Xor(FilterCondition): + """Represents the logical XOR of multiple filter conditions.""" + + def __init__(self, conditions: List["FilterCondition"]): + super().__init__("xor", conditions, use_infix_repr=False) diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index f2f081fee..6d83ae533 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -16,10 +16,12 @@ from typing import Generic, TypeVar, TYPE_CHECKING from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline +from google.cloud.firestore_v1._helpers import DOCUMENT_PATH_DELIMITER if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.base_document import BaseDocumentReference PipelineType = TypeVar("PipelineType", bound=_BasePipeline) @@ -41,13 +43,45 @@ def __init__(self, client: Client | AsyncClient): def _create_pipeline(self, source_stage): return self.client._pipeline_cls._create_with_stages(self.client, source_stage) - def collection(self, path: str) -> PipelineType: + def collection(self, path: str | tuple[str]) -> PipelineType: """ Creates a new Pipeline that operates on a specified Firestore collection. 
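
+
+        Example:
+            >>> # `source` is an illustrative PipelineSource bound to a client;
+            >>> # the two calls below sketch the two accepted path forms.
+            >>> source.collection("users")
+            >>> source.collection(("users", "user1", "orders"))
+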
         Args:
-            path: The path to the Firestore collection (e.g., "users")
+            path: The path to the Firestore collection (e.g., "users"). It can be either:
+                * A single ``/``-delimited path to a collection
+                * A tuple of collection path segments
         Returns:
             a new pipeline instance targeting the specified collection
         """
+        if isinstance(path, tuple):
+            path = DOCUMENT_PATH_DELIMITER.join(path)
         return self._create_pipeline(stages.Collection(path))
+
+    def collection_group(self, collection_id: str) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in a collection group.
+        Args:
+            collection_id: The ID of the collection group
+        Returns:
+            a new pipeline instance targeting the specified collection group
+        """
+        return self._create_pipeline(stages.CollectionGroup(collection_id))
+
+    def database(self) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in the Firestore database.
+        Returns:
+            a new pipeline instance targeting all documents in the database
+        """
+        return self._create_pipeline(stages.Database())
+
+    def documents(self, *docs: "BaseDocumentReference") -> PipelineType:
+        """
+        Creates a new Pipeline that operates on a specific set of Firestore documents.
+        Args:
+            docs: The DocumentReference instances representing the documents to include in the pipeline.
+        Returns:
+            a new pipeline instance targeting the specified documents
+        """
+        return self._create_pipeline(stages.Documents.of(*docs))
diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml
new file mode 100644
index 000000000..dc262f4a9
--- /dev/null
+++ b/tests/system/pipeline_e2e.yaml
@@ -0,0 +1,1640 @@
+data:
+  books:
+    book1:
+      title: "The Hitchhiker's Guide to the Galaxy"
+      author: "Douglas Adams"
+      genre: "Science Fiction"
+      published: 1979
+      rating: 4.2
+      tags:
+        - comedy
+        - space
+        - adventure
+      awards:
+        hugo: true
+        nebula: false
+    book2:
+      title: "Pride and Prejudice"
+      author: "Jane Austen"
+      genre: "Romance"
+      published: 1813
+      rating: 4.5
+      tags:
+        - classic
+        - social commentary
+        - love
+      awards:
+        none: true
+    book3:
+      title: "One Hundred Years of Solitude"
+      author: "Gabriel García Márquez"
+      genre: "Magical Realism"
+      published: 1967
+      rating: 4.3
+      tags:
+        - family
+        - history
+        - fantasy
+      awards:
+        nobel: true
+        nebula: false
+    book4:
+      title: "The Lord of the Rings"
+      author: "J.R.R. Tolkien"
+      genre: "Fantasy"
+      published: 1954
+      rating: 4.7
+      tags:
+        - adventure
+        - magic
+        - epic
+      awards:
+        hugo: false
+        nebula: false
+    book5:
+      title: "The Handmaid's Tale"
+      author: "Margaret Atwood"
+      genre: "Dystopian"
+      published: 1985
+      rating: 4.1
+      tags:
+        - feminism
+        - totalitarianism
+        - resistance
+      awards:
+        arthur c. clarke: true
+        booker prize: false
+    book6:
+      title: "Crime and Punishment"
+      author: "Fyodor Dostoevsky"
+      genre: "Psychological Thriller"
+      published: 1866
+      rating: 4.3
+      tags:
+        - philosophy
+        - crime
+        - redemption
+      awards:
+        none: true
+    book7:
+      title: "To Kill a Mockingbird"
+      author: "Harper Lee"
+      genre: "Southern Gothic"
+      published: 1960
+      rating: 4.2
+      tags:
+        - racism
+        - injustice
+        - coming-of-age
+      awards:
+        pulitzer: true
+    book8:
+      title: "1984"
+      author: "George Orwell"
+      genre: "Dystopian"
+      published: 1949
+      rating: 4.2
+      tags:
+        - surveillance
+        - totalitarianism
+        - propaganda
+      awards:
+        prometheus: true
+    book9:
+      title: "The Great Gatsby"
+      author: "F.
Scott Fitzgerald" + genre: "Modernist" + published: 1925 + rating: 4.0 + tags: + - wealth + - american dream + - love + awards: + none: true + book10: + title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true +tests: + - description: "testAggregates - count" + pipeline: + - Collection: books + - Aggregate: + - ExprWithAlias: + - Count + - "count" + assert_results: + - count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + - mapValue: {} + name: aggregate + - description: "testAggregates - avg, count, max" + pipeline: + - Collection: books + - Where: + - Eq: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Avg: + - Field: rating + - "avg_rating" + - ExprWithAlias: + - Max: + - Field: rating + - "max_rating" + assert_results: + - count: 2 + avg_rating: 4.4 + max_rating: 4.6 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + count: + functionValue: + name: count + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + - mapValue: {} + name: aggregate + - description: testGroupBysWithoutAccumulators + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1900 + - Aggregate: + accumulators: [] + groups: [genre] + assert_error: ".* requires at least one accumulator" + - description: testGroupBysAndAggregate + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1984 + - Aggregate: + accumulators: + - ExprWithAlias: + - Avg: + - Field: rating + - "avg_rating" + groups: [genre] + - Where: + - Gt: + - Field: avg_rating + - Constant: 4.3 + - Sort: + - Ordering: + - Field: avg_rating + - ASCENDING + assert_results: + - avg_rating: 4.4 + genre: Science Fiction + - avg_rating: 4.5 + genre: Romance + - avg_rating: 4.7 + genre: Fantasy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1984' + name: lt + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: aggregate + - args: + - functionValue: + args: + - fieldReferenceValue: avg_rating + - doubleValue: 4.3 + name: gt + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: avg_rating + name: sort + - description: testMinMax + pipeline: + - Collection: books + - Aggregate: + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Max: + - Field: rating + - "max_rating" + - ExprWithAlias: + - Min: + - Field: published + - "min_published" + assert_results: + - count: 10 + max_rating: 4.7 + min_published: 1813 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + max_rating: + functionValue: + args: 
+ - fieldReferenceValue: rating + name: maximum + min_published: + functionValue: + args: + - fieldReferenceValue: published + name: minimum + - mapValue: {} + name: aggregate + - description: selectSpecificFields + pipeline: + - Collection: books + - Select: + - title + - author + - Sort: + - Ordering: + - Field: author + - ASCENDING + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + - title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + - title: "Dune" + author: "Frank Herbert" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + - title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + - title: "1984" + author: "George Orwell" + - title: "To Kill a Mockingbird" + author: "Harper Lee" + - title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - description: addAndRemoveFields + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + - StrConcat: + - Field: author + - Constant: _ + - Field: title + - "author_title" + - ExprWithAlias: + - StrConcat: + - Field: title + - Constant: _ + - Field: author + - "title_author" + - RemoveFields: + - title_author + - tags + - awards + - rating + - title + - Field: published + - Field: genre + - Field: nestedField # Field does not exist, should be ignored + - Sort: + - Ordering: + - Field: author_title + - ASCENDING + assert_results: + - author: Douglas Adams + author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy + - author: F. Scott Fitzgerald + author_title: F. Scott Fitzgerald_The Great Gatsby + - author: Frank Herbert + author_title: Frank Herbert_Dune + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment + - author: Gabriel García Márquez + author_title: Gabriel García Márquez_One Hundred Years of Solitude + - author: George Orwell + author_title: George Orwell_1984 + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird + - author: J.R.R. Tolkien + author_title: J.R.R. 
Tolkien_The Lord of the Rings + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author_title: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: _ + - fieldReferenceValue: title + name: str_concat + title_author: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: _ + - fieldReferenceValue: author + name: str_concat + name: add_fields + - args: + - fieldReferenceValue: title_author + - fieldReferenceValue: tags + - fieldReferenceValue: awards + - fieldReferenceValue: rating + - fieldReferenceValue: title + - fieldReferenceValue: published + - fieldReferenceValue: genre + - fieldReferenceValue: nestedField + name: remove_fields + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author_title + name: sort + - description: whereByMultipleConditions + pipeline: + - Collection: books + - Where: + - And: + - Gt: + - Field: rating + - Constant: 4.5 + - Eq: + - Field: genre + - Constant: Science Fiction + assert_results: + - title: Dune + author: Frank Herbert + genre: Science Fiction + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + name: where + - description: whereByOrCondition + pipeline: + - Collection: books + - Where: + - Or: + - Eq: + - Field: genre + - Constant: Romance + - Eq: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: eq + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: eq + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testPipelineWithOffsetAndLimit + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Offset: 5 + - Limit: 3 + - Select: + - title + - author + assert_results: + - title: "1984" + author: George Orwell + - title: To Kill a Mockingbird + author: Harper Lee + - title: The Lord of the Rings + author: J.R.R. 
Tolkien + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - integerValue: '5' + name: offset + - args: + - integerValue: '3' + name: limit + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - description: testArrayContains + pipeline: + - Collection: books + - Where: + - ArrayContains: + - Field: tags + - Constant: comedy + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + awards: + hugo: true + nebula: false + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: ["comedy", "space", "adventure"] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - stringValue: comedy + name: array_contains + name: where + - description: testArrayContainsAny + pipeline: + - Collection: books + - Where: + - ArrayContainsAny: + - Field: tags + - - Constant: comedy + - Constant: classic + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Pride and Prejudice + - title: The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: comedy + - stringValue: classic + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayContainsAll + pipeline: + - Collection: books + - Where: + - ArrayContainsAll: + - Field: tags + - - Constant: adventure + - Constant: magic + - Select: + - title + assert_results: + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: adventure + - stringValue: magic + name: array_contains_all + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - description: testArrayLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ArrayLength: + - Field: tags + - "tagsCount" + - Where: + - Eq: + - Field: tagsCount + - Constant: 3 + assert_results: # All documents have 3 tags + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + tagsCount: + functionValue: + args: + - fieldReferenceValue: tags + name: array_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: tagsCount + - integerValue: '3' + name: eq + name: where + - description: testStrConcat + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Select: + - ExprWithAlias: + - StrConcat: + - Field: author + - Constant: " - " + - Field: title + - "bookInfo" + - Limit: 1 + assert_results: + - bookInfo: Douglas Adams - The 
Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - mapValue: + fields: + bookInfo: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: ' - ' + - fieldReferenceValue: title + name: str_concat + name: select + - args: + - integerValue: '1' + name: limit + - description: testStartsWith + pipeline: + - Collection: books + - Where: + - StartsWith: + - Field: title + - Constant: The + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: The Great Gatsby + - title: The Handmaid's Tale + - title: The Hitchhiker's Guide to the Galaxy + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + name: starts_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testEndsWith + pipeline: + - Collection: books + - Where: + - EndsWith: + - Field: title + - Constant: y + - Select: + - title + - Sort: + - Ordering: + - Field: title + - DESCENDING + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: y + name: ends_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - description: testLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - CharLength: + - Field: title + - "titleLength" + - title + - Where: + - Gt: + - Field: titleLength + - Constant: 20 + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - titleLength: 29 + title: One Hundred Years of Solitude + - titleLength: 36 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 21 + title: The Lord of the Rings + - titleLength: 21 + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + titleLength: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: titleLength + - integerValue: '20' + name: gt + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testStringFunctions - CharLength + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: "Douglas Adams" + - Select: + - ExprWithAlias: + - CharLength: + - Field: title + - "title_length" + assert_results: + - title_length: 36 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - 
mapValue: + fields: + title_length: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - description: testStringFunctions - ByteLength + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + - Select: + - ExprWithAlias: + - ByteLength: + - StrConcat: + - Field: title + - Constant: _银河系漫游指南 + - "title_byte_length" + assert_results: + - title_byte_length: 58 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - mapValue: + fields: + title_byte_length: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" + name: str_concat + name: byte_length + name: select + - description: testLike + pipeline: + - Collection: books + - Where: + - Like: + - Field: title + - Constant: "%Guide%" + - Select: + - title + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - description: testRegexContains + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - RegexContains: + - Field: title + - Constant: "(?i)(the|of)" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "(?i)(the|of)" + name: regex_contains + name: where + - description: testRegexMatches + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - RegexMatch: + - Field: title + - Constant: ".*(?i)(the|of).*" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: ".*(?i)(the|of).*" + name: regex_match + name: where + - description: testArithmeticOperations + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - ExprWithAlias: + - Add: + - Field: rating + - Constant: 1 + - "ratingPlusOne" + - ExprWithAlias: + - Subtract: + - Field: published + - Constant: 1900 + - "yearsSince1900" + - ExprWithAlias: + - Multiply: + - Field: rating + - Constant: 10 + - "ratingTimesTen" + - ExprWithAlias: + - Divide: + - Field: rating + - Constant: 2 + - "ratingDividedByTwo" + - ExprWithAlias: + - Multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - ExprWithAlias: + - Add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - ExprWithAlias: + - Mod: + - Field: rating + - Constant: 2 + - "ratingMod2" + assert_results: + - ratingPlusOne: 5.2 + yearsSince1900: 60 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: eq + name: where + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: 
rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: mod + name: select + - description: testComparisonOperators + pipeline: + - Collection: books + - Where: + - And: + - Gt: + - Field: rating + - Constant: 4.2 + - Lte: + - Field: rating + - Constant: 4.5 + - Neq: + - Field: genre + - Constant: Science Fiction + - Select: + - rating + - title + - Sort: + - Ordering: + - title + - ASCENDING + assert_results: + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: gt + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: lte + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: neq + name: and + name: where + - args: + - mapValue: + fields: + rating: + fieldReferenceValue: rating + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testLogicalOperators + pipeline: + - Collection: books + - Where: + - Or: + - And: + - Gt: + - Field: rating + - Constant: 4.5 + - Eq: + - Field: genre + - Constant: Science Fiction + - Lt: + - Field: published + - Constant: 1900 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testChecks + pipeline: + - Collection: books + - Where: + - Not: + - IsNaN: + - Field: rating + - Select: + - ExprWithAlias: + - Not: + - IsNaN: + - Field: rating + - "ratingIsNotNaN" + - Limit: 1 + assert_results: + - ratingIsNotNaN: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: where + - args: + - mapValue: + fields: + ratingIsNotNaN: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: select + - args: + - integerValue: '1' + name: limit + - description: 
testLogicalMinMax + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + - Select: + - ExprWithAlias: + - LogicalMax: + - Field: rating + - Constant: 4.5 + - "max_rating" + - ExprWithAlias: + - LogicalMax: + - Field: published + - Constant: 1900 + - "max_published" + assert_results: + - max_rating: 4.5 + max_published: 1979 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - mapValue: + fields: + max_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: logical_maximum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: logical_maximum + name: select + - description: testMapGet + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: published + - DESCENDING + - Select: + - ExprWithAlias: + - MapGet: + - Field: awards + - Constant: hugo + - "hugoAward" + - Field: title + - Where: + - Eq: + - Field: hugoAward + - Constant: true + assert_results: + - hugoAward: true + title: The Hitchhiker's Guide to the Galaxy + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: published + name: sort + - args: + - mapValue: + fields: + hugoAward: + functionValue: + args: + - fieldReferenceValue: awards + - stringValue: hugo + name: map_get + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: hugoAward + - booleanValue: true + name: eq + name: where + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - Eq: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: eq + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: + - 0.6 + - percent + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - doubleValue: 0.6 + - stringValue: percent + name: sample + - description: testUnion + pipeline: + - Collection: books + - Union: + - Pipeline: + - Collection: books + assert_count: 20 # Results will be duplicated + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - pipelineValue: + stages: + 
- args: + - referenceValue: /books + name: collection + name: union + - description: testUnnest + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + - tags + - tags_alias + - Select: tags_alias + assert_results: + - tags_alias: comedy + - tags_alias: space + - tags_alias: adventure + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + name: select diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index d6ee9b944..c146a5763 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -10,7 +10,13 @@ RANDOM_ID_REGEX = re.compile("^[a-zA-Z0-9]{20}$") MISSING_DOCUMENT = "No document to update: " DOCUMENT_EXISTS = "Document already exists: " +ENTERPRISE_MODE_ERROR = "only allowed on ENTERPRISE mode" UNIQUE_RESOURCE_ID = unique_resource_id("-") EMULATOR_CREDS = EmulatorCreds() FIRESTORE_EMULATOR = os.environ.get(_FIRESTORE_EMULATOR_HOST) is not None FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") +FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db") + +# run all tests against default database, and a named database +# TODO: add enterprise mode when GA (RunQuery not currently supported) +TEST_DATABASES = [None, FIRESTORE_OTHER_DB] diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py new file mode 100644 index 000000000..9d44bbc57 --- /dev/null +++ b/tests/system/test_pipeline_acceptance.py @@ -0,0 +1,285 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This file loads and executes yaml-encoded test cases from pipeline_e2e.yaml +""" + +from __future__ import annotations +import os +import pytest +import yaml +import re +from typing import Any + +from google.protobuf.json_format import MessageToDict + +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_expressions +from google.api_core.exceptions import GoogleAPIError + +from google.cloud.firestore import Client, AsyncClient + +from test__helpers import FIRESTORE_ENTERPRISE_DB + +FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") + +test_dir_name = os.path.dirname(__file__) + + +def yaml_loader(field="tests", file_name="pipeline_e2e.yaml"): + """ + Helper to load test cases or data from yaml file + """ + with open(f"{test_dir_name}/{file_name}") as f: + test_cases = yaml.safe_load(f) + return test_cases[field] + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_proto" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_parse_proto(test_dict, client): + """ + Finds assert_proto statements in yaml, and compares generated proto against expected value + """ + expected_proto = test_dict.get("assert_proto", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if proto matches as expected + if expected_proto: + got_proto = MessageToDict(pipeline._to_pb()._pb) + assert yaml.dump(expected_proto) == yaml.dump(got_proto) + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_expected_errors(test_dict, client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + pipeline.execute() + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_results(test_dict, client): + """ + Ensure pipeline returns expected results + """ + expected_results = test_dict.get("assert_results", None) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.data() for snapshot in pipeline.stream()] + if expected_results: + assert got_results == expected_results + if expected_count is not None: + assert len(got_results) == expected_count + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}", +) +@pytest.mark.asyncio +async def test_pipeline_expected_errors_async(test_dict, async_client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + await pipeline.execute() + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match 
+
+
+@pytest.mark.parametrize(
+    "test_dict",
+    [t for t in yaml_loader() if "assert_proto" in t],
+    ids=lambda x: f"{x.get('description', '')}",
+)
+def test_pipeline_parse_proto(test_dict, client):
+    """
+    Finds assert_proto statements in the YAML and compares the generated proto against the expected value
+    """
+    expected_proto = test_dict.get("assert_proto", None)
+    pipeline = parse_pipeline(client, test_dict["pipeline"])
+    # check if proto matches as expected
+    if expected_proto:
+        got_proto = MessageToDict(pipeline._to_pb()._pb)
+        assert yaml.dump(expected_proto) == yaml.dump(got_proto)
+
+
+@pytest.mark.parametrize(
+    "test_dict",
+    [t for t in yaml_loader() if "assert_error" in t],
+    ids=lambda x: f"{x.get('description', '')}",
+)
+def test_pipeline_expected_errors(test_dict, client):
+    """
+    Finds assert_error statements in the YAML and ensures the pipeline raises the expected error
+    """
+    error_regex = test_dict["assert_error"]
+    pipeline = parse_pipeline(client, test_dict["pipeline"])
+    # check if server responds as expected
+    with pytest.raises(GoogleAPIError) as err:
+        pipeline.execute()
+    found_error = str(err.value)
+    match = re.search(error_regex, found_error)
+    assert match, f"error '{found_error}' does not match '{error_regex}'"
+
+
+@pytest.mark.parametrize(
+    "test_dict",
+    [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t],
+    ids=lambda x: f"{x.get('description', '')}",
+)
+def test_pipeline_results(test_dict, client):
+    """
+    Ensure the pipeline returns the expected results
+    """
+    expected_results = test_dict.get("assert_results", None)
+    expected_count = test_dict.get("assert_count", None)
+    pipeline = parse_pipeline(client, test_dict["pipeline"])
+    # check if server responds as expected
+    got_results = [snapshot.data() for snapshot in pipeline.stream()]
+    if expected_results:
+        assert got_results == expected_results
+    if expected_count is not None:
+        assert len(got_results) == expected_count
+
+
+@pytest.mark.parametrize(
+    "test_dict",
+    [t for t in yaml_loader() if "assert_error" in t],
+    ids=lambda x: f"{x.get('description', '')}",
+)
+@pytest.mark.asyncio
+async def test_pipeline_expected_errors_async(test_dict, async_client):
+    """
+    Finds assert_error statements in the YAML and ensures the pipeline raises the expected error
+    """
+    error_regex = test_dict["assert_error"]
+    pipeline = parse_pipeline(async_client, test_dict["pipeline"])
+    # check if server responds as expected
+    with pytest.raises(GoogleAPIError) as err:
+        await pipeline.execute()
+    found_error = str(err.value)
+    match = re.search(error_regex, found_error)
+    assert match, f"error '{found_error}' does not match '{error_regex}'"
+
+
+@pytest.mark.parametrize(
+    "test_dict",
+    [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t],
+    ids=lambda x: f"{x.get('description', '')}",
+)
+@pytest.mark.asyncio
+async def test_pipeline_results_async(test_dict, async_client):
+    """
+    Ensure the pipeline returns the expected results
+    """
+    expected_results = test_dict.get("assert_results", None)
+    expected_count = test_dict.get("assert_count", None)
+    pipeline = parse_pipeline(async_client, test_dict["pipeline"])
+    # check if server responds as expected
+    got_results = [snapshot.data() async for snapshot in pipeline.stream()]
+    if expected_results:
+        assert got_results == expected_results
+    if expected_count is not None:
+        assert len(got_results) == expected_count
+
+
+#################################################################################
+# Helpers & Fixtures
+#################################################################################
+
+
+def parse_pipeline(client, pipeline: list[dict[str, Any] | str]):
+    """
+    Parse a YAML list of pipeline stages into firestore._pipeline_stages.Stage instances
+    """
+    result_list = []
+    for stage in pipeline:
+        # stage will be either a map of the stage_name to its args, or just the stage_name itself
+        stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0]
+        stage_cls: type[stages.Stage] = getattr(stages, stage_name)
+        # find arguments if given
+        if isinstance(stage, dict):
+            stage_yaml_args = stage[stage_name]
+            stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args)
+        else:
+            # yaml has no arguments
+            stage_obj = stage_cls()
+        result_list.append(stage_obj)
+    return client._pipeline_cls._create_with_stages(client, *result_list)
+
+
+def _parse_expressions(client, yaml_element: Any):
+    """
+    Turn YAML objects into pipeline expressions or native Python argument values
+    """
+    if isinstance(yaml_element, list):
+        return [_parse_expressions(client, v) for v in yaml_element]
+    elif isinstance(yaml_element, dict):
+        if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))):
+            # build pipeline expressions if possible
+            cls_str = next(iter(yaml_element))
+            cls = getattr(pipeline_expressions, cls_str)
+            yaml_args = yaml_element[cls_str]
+            return _apply_yaml_args(cls, client, yaml_args)
+        elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))):
+            # build pipeline stage if possible (e.g., for SampleOptions)
+            cls_str = next(iter(yaml_element))
+            cls = getattr(stages, cls_str)
+            yaml_args = yaml_element[cls_str]
+            return _apply_yaml_args(cls, client, yaml_args)
+        elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline":
+            # find Pipeline objects for Union expressions
+            other_ppl = yaml_element["Pipeline"]
+            return parse_pipeline(client, other_ppl)
+        else:
+            # otherwise, return dict
+            return {
+                _parse_expressions(client, k): _parse_expressions(client, v)
+                for k, v in yaml_element.items()
+            }
+    elif _is_expr_string(yaml_element):
+        return getattr(pipeline_expressions, yaml_element)()
+    else:
+        return yaml_element
+
+
+def _apply_yaml_args(cls, client, yaml_args):
+    """
+    Helper to instantiate a class with YAML arguments. The arguments are applied
+    as positional or keyword arguments, based on type
+    """
+    if isinstance(yaml_args, dict):
+        return cls(**_parse_expressions(client, yaml_args))
+    elif isinstance(yaml_args, list):
+        # yaml has an array of arguments; treat as positional args
+        return cls(*_parse_expressions(client, yaml_args))
+    else:
+        # yaml has a single argument
+        return cls(_parse_expressions(client, yaml_args))
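+
+
+# For illustration, a YAML stage entry such as:
+#
+#   - Where:
+#       - Eq:
+#           - Field: author
+#           - Constant: Douglas Adams
+#
+# is resolved by parse_pipeline and the helpers above into (roughly):
+#
+#   stages.Where(
+#       pipeline_expressions.Eq(
+#           pipeline_expressions.Field("author"),
+#           pipeline_expressions.Constant("Douglas Adams"),
+#       )
+#   )
+#
+# because list arguments are applied positionally and dict arguments as kwargs.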
+
+
+def _is_expr_string(yaml_str):
+    """
+    Returns True if the string names a class in pipeline_expressions
+    """
+    return (
+        isinstance(yaml_str, str)
+        and yaml_str[0].isupper()
+        and hasattr(pipeline_expressions, yaml_str)
+    )
+
+
+def _is_stage_string(yaml_str):
+    """
+    Returns True if the string names a class in _pipeline_stages
+    """
+    return (
+        isinstance(yaml_str, str)
+        and yaml_str[0].isupper()
+        and hasattr(stages, yaml_str)
+    )
+
+
+@pytest.fixture(scope="module")
+def event_loop():
+    """Change event_loop fixture to module level."""
+    import asyncio
+
+    policy = asyncio.get_event_loop_policy()
+    loop = policy.new_event_loop()
+    yield loop
+    loop.close()
+
+
+@pytest.fixture(scope="module")
+def client():
+    """
+    Build a client to use for requests
+    """
+    client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB)
+    data = yaml_loader("data")
+    try:
+        # setup data
+        batch = client.batch()
+        for collection_name, documents in data.items():
+            collection_ref = client.collection(collection_name)
+            for document_id, document_data in documents.items():
+                document_ref = collection_ref.document(document_id)
+                batch.set(document_ref, document_data)
+        batch.commit()
+        yield client
+    finally:
+        # clear data
+        for collection_name, documents in data.items():
+            collection_ref = client.collection(collection_name)
+            for document_id in documents:
+                document_ref = collection_ref.document(document_id)
+                document_ref.delete()
+
+
+@pytest.fixture(scope="module")
+def async_client(client):
+    """
+    Build an async client to use for AsyncPipeline requests
+    """
+    yield AsyncClient(project=client.project, database=client._database)
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index c66340de1..9909fb05e 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -38,11 +38,11 @@
     EMULATOR_CREDS,
     FIRESTORE_CREDS,
     FIRESTORE_EMULATOR,
-    FIRESTORE_OTHER_DB,
     FIRESTORE_PROJECT,
     MISSING_DOCUMENT,
     RANDOM_ID_REGEX,
     UNIQUE_RESOURCE_ID,
+    TEST_DATABASES,
 )
@@ -80,13 +80,13 @@ def cleanup():
         operation()
 
 
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collections(client, database):
     collections = list(client.collections())
     assert isinstance(collections, list)
 
 
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB])
+@pytest.mark.parametrize("database", TEST_DATABASES)
 def test_collections_w_import(database):
     from google.cloud import firestore
@@ -103,7 +103,7 @@
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method):
     from google.cloud.firestore_v1.query_profile import QueryExplainError
@@ -125,7 +125,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_false( database, method, query_docs ): @@ -163,7 +163,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_true( database, method, query_docs ): @@ -217,7 +217,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert len(execution_stats.debug_stats) > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections_w_read_time(client, cleanup, database): first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID first_document_id = "doc" + UNIQUE_RESOURCE_ID @@ -228,7 +228,6 @@ def test_collections_w_read_time(client, cleanup, database): data = {"status": "new"} write_result = first_document.create(data) read_time = write_result.update_time - num_collections = len(list(client.collections())) second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" @@ -238,7 +237,6 @@ def test_collections_w_read_time(client, cleanup, database): # Test that listing current collections does have the second id. curr_collections = list(client.collections()) - assert len(curr_collections) > num_collections ids = [collection.id for collection in curr_collections] assert second_collection_id in ids assert first_collection_id in ids @@ -250,7 +248,7 @@ def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) collection_id = "doc-create" + UNIQUE_RESOURCE_ID @@ -295,7 +293,7 @@ def test_create_document(client, cleanup, database): assert stored_data == expected_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_vector(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document1 = client.document(collection_id, "doc1") @@ -326,7 +324,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -355,7 +353,7 @@ def test_vector_search_collection(client, database, distance_measure): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -384,7 +382,7 @@ def 
test_vector_search_collection_with_filter(client, database, distance_measure @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_euclid(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -414,7 +412,7 @@ def test_vector_search_collection_with_distance_parameters_euclid(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_cosine(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -444,7 +442,7 @@ def test_vector_search_collection_with_distance_parameters_cosine(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -480,7 +478,7 @@ def test_vector_search_collection_group(client, database, distance_measure): DistanceMeasure.COSINE, ], ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -502,7 +500,7 @@ def test_vector_search_collection_group_with_filter(client, database, distance_m @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): @@ -534,7 +532,7 @@ def test_vector_search_collection_group_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_cosine( client, database ): @@ -569,7 +567,7 @@ def test_vector_search_collection_group_with_distance_parameters_cosine( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_no_explain_options(client, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -599,7 +597,7 @@ def test_vector_query_stream_or_get_w_no_explain_options(client, database, metho FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_true( client, database, method ): @@ -668,7 +666,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_false( client, database, method ): @@ -715,7 +713,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -741,7 +739,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_collections_w_read_time(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -777,7 +775,7 @@ def test_document_collections_w_read_time(client, cleanup, database): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID document = client.document("abcde", document_id) @@ -785,7 +783,7 @@ def test_no_document(client, database): assert snapshot.to_dict() is None -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -815,7 +813,7 @@ def test_document_set(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_integer_field(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -832,7 +830,7 @@ def test_document_integer_field(client, cleanup, database): assert snapshot.to_dict() == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set_merge(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -865,7 +863,7 @@ def test_document_set_merge(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def 
test_document_set_w_int_field(client, cleanup, database): document_id = "set-int-key" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -889,7 +887,7 @@ def test_document_set_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_update_w_int_field(client, cleanup, database): # Attempt to reproduce #5489. document_id = "update-int-key" + UNIQUE_RESOURCE_ID @@ -917,7 +915,7 @@ def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_update_document(client, cleanup, database): document_id = "for-update" + UNIQUE_RESOURCE_ID document = client.document("made", document_id) @@ -989,7 +987,7 @@ def check_snapshot(snapshot, document, data, write_result): assert snapshot.update_time == write_result.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_get(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) document_id = "for-get" + UNIQUE_RESOURCE_ID @@ -1015,7 +1013,7 @@ def test_document_get(client, cleanup, database): check_snapshot(snapshot, document, data, write_result) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_delete(client, cleanup, database): document_id = "deleted" + UNIQUE_RESOURCE_ID document = client.document("here-to-be", document_id) @@ -1052,7 +1050,7 @@ def test_document_delete(client, cleanup, database): assert_timestamp_less(delete_time3, delete_time4) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_add(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1141,7 +1139,7 @@ def test_collection_add(client, cleanup, database): assert set(collection3.list_documents()) == {document_ref5} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. 
@@ -1166,7 +1164,7 @@ def test_list_collections_with_read_time(client, cleanup, database): } -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_unicode_doc(client, cleanup, database): collection_id = "coll-unicode" + UNIQUE_RESOURCE_ID collection = client.collection(collection_id) @@ -1233,7 +1231,7 @@ def query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs @@ -1249,7 +1247,7 @@ def test_query_stream_legacy_where(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1260,7 +1258,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1271,7 +1269,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1283,7 +1281,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) @@ -1305,7 +1303,7 @@ def test_query_stream_w_not_eq_op(query_docs, database): assert expected_ab_pairs == ab_pairs2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1317,7 +1315,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): assert len(values) == 22 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1331,7 +1329,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], 
indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1345,7 +1343,7 @@ def test_query_stream_w_order_by(query_docs, database): assert sorted(b_vals, reverse=True) == b_vals -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1367,7 +1365,7 @@ def test_query_stream_w_field_path(query_docs, database): assert expected_ab_pairs == ab_pairs2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1383,7 +1381,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1392,7 +1390,7 @@ def test_query_stream_wo_results(query_docs, database): assert len(values) == 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1409,7 +1407,7 @@ def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1429,7 +1427,7 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert pair in matching_pairs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1449,7 +1447,7 @@ def test_query_stream_w_offset(query_docs, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1471,7 +1469,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1531,7 +1529,7 @@ def test_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1571,7 +1569,7 @@ def test_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1609,7 +1607,7 @@ def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1622,15 +1620,16 @@ def test_query_with_order_dot_key(client, cleanup, database): query = collection.order_by("wordcount.page1").limit(3) data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()] assert [100, 110, 120] == data - for snapshot in collection.order_by("wordcount.page1").limit(3).stream(): + query2 = collection.order_by("wordcount.page1").limit(3) + for snapshot in query2.stream(): last_value = snapshot.get("wordcount.page1") cursor_with_nested_keys = {"wordcount": {"page1": last_value}} - found = list( + query3 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_nested_keys) .limit(3) - .stream() ) + found = list(query3.stream()) found_data = [ {"count": 30, "wordcount": {"page1": 130}}, {"count": 40, "wordcount": {"page1": 140}}, @@ -1638,16 +1637,16 @@ def test_query_with_order_dot_key(client, cleanup, database): ] assert found_data == [snap.to_dict() for snap in found] cursor_with_dotted_paths = {"wordcount.page1": last_value} - cursor_with_key_data = list( + query4 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_dotted_paths) .limit(3) - .stream() ) + cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1702,7 +1701,7 @@ def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1735,7 +1734,7 @@ def test_collection_group_queries(client, cleanup, 
database): assert found == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1778,7 +1777,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1847,7 +1846,7 @@ def test_collection_group_queries_filters(client, cleanup, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_partition_query_no_partitions(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1882,7 +1881,7 @@ def test_partition_query_no_partitions(client, cleanup, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_partition_query(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID n_docs = 128 * 2 + 127 # Minimum partition size is 128 @@ -1910,7 +1909,7 @@ def test_partition_query(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_get_all(client, cleanup, database): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -1986,7 +1985,7 @@ def test_get_all(client, cleanup, database): check_snapshot(snapshot3, document3, data3, write_result3) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_batch(client, cleanup, database): collection_name = "batch" + UNIQUE_RESOURCE_ID @@ -2032,7 +2031,7 @@ def test_batch(client, cleanup, database): assert not document3.get().exists -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_live_bulk_writer(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriter from google.cloud.firestore_v1.client import Client @@ -2056,7 +2055,7 @@ def test_live_bulk_writer(client, cleanup, database): assert len(col.get()) == 50 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_document(client, cleanup, database): db = client collection_ref = db.collection("wd-users" + UNIQUE_RESOURCE_ID) @@ -2093,7 +2092,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_collection(client, cleanup, database): db = client collection_ref = db.collection("wc-users" + UNIQUE_RESOURCE_ID) @@ -2130,7 
+2129,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) @@ -2148,9 +2147,8 @@ def on_snapshot(docs, changes, read_time): on_snapshot.called_count += 1 # A snapshot should return the same thing as if a query ran now. - query_ran = collection_ref.where( - filter=FieldFilter("first", "==", "Ada") - ).stream() + query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) + query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) on_snapshot.called_count = 0 @@ -2172,7 +2170,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_array_union(client, cleanup, database): doc_ref = client.document("gcp-7523", "test-document") cleanup(doc_ref.delete) @@ -2319,7 +2317,7 @@ def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): ), f"Snapshot at Socrates{path} should have been deleted" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_parallelized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2327,7 +2325,7 @@ def test_recursive_delete_parallelized(client, cleanup, database): _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_serialized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2335,7 +2333,7 @@ def test_recursive_delete_serialized(client, cleanup, database): _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_parallelized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2343,7 +2341,7 @@ def test_recursive_delete_parallelized_empty(client, cleanup, database): _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_delete_serialized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2351,12 +2349,13 @@ def test_recursive_delete_serialized_empty(client, cleanup, database): _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) - ids = [doc.id for doc in client.collection_group(col_id).recursive().get()] + query = client.collection_group(col_id).recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle 
doc and subdocs @@ -2390,14 +2389,15 @@ def test_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}" _persist_documents(client, col_id, philosophers_data_set, cleanup) collection_ref = client.collection(col_id) aristotle = collection_ref.document("Aristotle") - ids = [doc.id for doc in aristotle.collection("pets").recursive().get()] + query = aristotle.collection("pets").recursive() + ids = [doc.id for doc in query.get()] expected_ids = [ # Aristotle pets @@ -2414,7 +2414,7 @@ def test_nested_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_query(client, cleanup, database): col = client.collection(f"chunked-test{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2429,7 +2429,7 @@ def test_chunked_query(client, cleanup, database): assert len(next(iter)) == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_query_smaller_limit(client, cleanup, database): col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2441,7 +2441,7 @@ def test_chunked_query_smaller_limit(client, cleanup, database): assert len(next(iter)) == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_chunked_and_recursive(client, cleanup, database): col_id = f"chunked-recursive-test{UNIQUE_RESOURCE_ID}" documents = [ @@ -2490,7 +2490,7 @@ def test_chunked_and_recursive(client, cleanup, database): assert [doc.id for doc in next(iter)] == page_3_ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") @@ -2566,7 +2566,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_repro_429(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2594,7 +2594,7 @@ def test_repro_429(client, cleanup, database): print(f"id: {snapshot.id}") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_repro_391(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/391 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2609,7 +2609,7 @@ def test_repro_391(client, cleanup, database): assert len(set(collection.stream())) == len(document_ids) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_default_alias(query, database): count_query = 
query.count() result = count_query.get() @@ -2618,7 +2618,7 @@ def test_count_query_get_default_alias(query, database): assert r.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_with_alias(query, database): count_query = query.count(alias="total") result = count_query.get() @@ -2627,7 +2627,7 @@ def test_count_query_get_with_alias(query, database): assert r.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_with_limit(query, database): # count without limit count_query = query.count(alias="total") @@ -2647,7 +2647,7 @@ def test_count_query_get_with_limit(query, database): assert r.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_multiple_aggregations(query, database): count_query = query.count(alias="total").count(alias="all") @@ -2662,7 +2662,7 @@ def test_count_query_get_multiple_aggregations(query, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_multiple_aggregations_duplicated_alias(query, database): count_query = query.count(alias="total").count(alias="total") @@ -2672,7 +2672,7 @@ def test_count_query_get_multiple_aggregations_duplicated_alias(query, database) assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_get_empty_aggregation(query, database): from google.cloud.firestore_v1.aggregation import AggregationQuery @@ -2684,7 +2684,7 @@ def test_count_query_get_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_default_alias(query, database): count_query = query.count() for result in count_query.stream(): @@ -2692,7 +2692,7 @@ def test_count_query_stream_default_alias(query, database): assert aggregation_result.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_with_alias(query, database): count_query = query.count(alias="total") for result in count_query.stream(): @@ -2700,7 +2700,7 @@ def test_count_query_stream_with_alias(query, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_with_limit(query, database): # count without limit count_query = query.count(alias="total") @@ -2718,7 +2718,7 @@ def test_count_query_stream_with_limit(query, database): assert aggregation_result.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def 
test_count_query_stream_multiple_aggregations(query, database): count_query = query.count(alias="total").count(alias="all") @@ -2727,7 +2727,7 @@ def test_count_query_stream_multiple_aggregations(query, database): assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_multiple_aggregations_duplicated_alias(query, database): count_query = query.count(alias="total").count(alias="total") @@ -2738,7 +2738,7 @@ def test_count_query_stream_multiple_aggregations_duplicated_alias(query, databa assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_stream_empty_aggregation(query, database): from google.cloud.firestore_v1.aggregation import AggregationQuery @@ -2751,7 +2751,7 @@ def test_count_query_stream_empty_aggregation(query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_count_query_with_start_at(query, database): """ Ensure that count aggregation queries work when chained with a start_at @@ -2770,7 +2770,7 @@ def test_count_query_with_start_at(query, database): assert aggregation_result.value == expected_count -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_default_alias(collection, database): sum_query = collection.sum("stats.product") result = sum_query.get() @@ -2780,7 +2780,7 @@ def test_sum_query_get_default_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") result = sum_query.get() @@ -2790,7 +2790,7 @@ def test_sum_query_get_with_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2811,7 +2811,7 @@ def test_sum_query_get_with_limit(collection, database): assert r.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_get_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2828,7 +2828,7 @@ def test_sum_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_default_alias(collection, database): sum_query = collection.sum("stats.product") for result in sum_query.stream(): @@ -2837,7 +2837,7 @@ def test_sum_query_stream_default_alias(collection, 
database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") for result in sum_query.stream(): @@ -2846,7 +2846,7 @@ def test_sum_query_stream_with_alias(collection, database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2864,7 +2864,7 @@ def test_sum_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_stream_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2878,7 +2878,7 @@ def test_sum_query_stream_multiple_aggregations(collection, database): # tests for issue reported in b/306241058 # we will skip test in client for now, until backend fix is implemented @pytest.mark.skip(reason="backend fix required") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_sum_query_with_start_at(query, database): """ Ensure that sum aggregation queries work when chained with a start_at @@ -2896,7 +2896,7 @@ def test_sum_query_with_start_at(query, database): assert sum_result[0].value == expected_sum -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_default_alias(collection, database): avg_query = collection.avg("stats.product") result = avg_query.get() @@ -2907,7 +2907,7 @@ def test_avg_query_get_default_alias(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") result = avg_query.get() @@ -2917,7 +2917,7 @@ def test_avg_query_get_with_alias(collection, database): assert r.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2939,7 +2939,7 @@ def test_avg_query_get_with_limit(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_get_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2956,7 +2956,7 @@ def test_avg_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_default_alias(collection, database): avg_query = collection.avg("stats.product") for result in avg_query.stream(): @@ -2965,7 +2965,7 @@ def test_avg_query_stream_default_alias(collection, database): assert aggregation_result.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") for result in avg_query.stream(): @@ -2974,7 +2974,7 @@ def test_avg_query_stream_with_alias(collection, database): assert aggregation_result.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2992,7 +2992,7 @@ def test_avg_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 / 12 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_stream_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -3006,7 +3006,7 @@ def test_avg_query_stream_multiple_aggregations(collection, database): # tests for issue reported in b/306241058 # we will skip test in client for now, until backend fix is implemented @pytest.mark.skip(reason="backend fix required") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_avg_query_with_start_at(query, database): """ Ensure that avg aggregation queries work when chained with a start_at @@ -3030,7 +3030,7 @@ def test_avg_query_with_start_at(query, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, method): # Because all aggregation methods end up calling AggregationQuery.get() or # AggregationQuery.stream(), only use count() for testing here. @@ -3056,7 +3056,7 @@ def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, m FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_explain_options_analyze_true( query, database, method ): @@ -3120,7 +3120,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( query, database, method ): @@ -3160,7 +3160,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ @@ -3175,7 +3175,7 @@ def test_query_with_and_composite_filter(collection, database): assert result.get("stats.product") < 10 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ @@ -3198,12 +3198,17 @@ def test_query_with_or_composite_filter(collection, database): assert lt_10 > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) def test_aggregation_queries_with_read_time( - collection, query, cleanup, database, aggregation_type, expected_value + collection, + query, + cleanup, + database, + aggregation_type, + expected_value, ): """ Ensure that all aggregation queries work when read_time is passed into @@ -3240,7 +3245,7 @@ def test_aggregation_queries_with_read_time( assert r.value == expected_value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( @@ -3289,9 +3294,14 @@ def test_query_with_complex_composite_filter(collection, database): "aggregation_type,aggregation_args,expected", [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_aggregation_query_in_transaction( - client, cleanup, database, aggregation_type, aggregation_args, expected + client, + cleanup, + database, + aggregation_type, + aggregation_args, + expected, ): """ Test creating an aggregation query inside a transaction @@ -3331,7 +3341,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_or_query_in_transaction(client, cleanup, database): """ Test running or query inside a transaction. 
Should pass transaction id along with request @@ -3376,7 +3386,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_transaction_w_uuid(client, cleanup, database): """ https://github.com/googleapis/python-firestore/issues/1012 @@ -3401,7 +3411,7 @@ def update_doc(tx, doc_ref, key, value): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_in_transaction_with_explain_options(client, cleanup, database): """ Test query profiling in transactions. @@ -3453,7 +3463,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_in_transaction_with_read_time(client, cleanup, database): """ Test query profiling in transactions. @@ -3499,7 +3509,7 @@ def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_update_w_uuid(client, cleanup, database): """ https://github.com/googleapis/python-firestore/issues/1012 @@ -3518,7 +3528,7 @@ def test_update_w_uuid(client, cleanup, database): @pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_transaction_rollback(client, cleanup, database, with_rollback, expected): """ Create a document in a transaction that is rolled back diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 945e7cb12..bc79ee2df 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -49,11 +49,11 @@ EMULATOR_CREDS, FIRESTORE_CREDS, FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + TEST_DATABASES, ) RETRIES = retries.AsyncRetry( @@ -169,13 +169,13 @@ def event_loop(): loop.close() -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collections(client, database): collections = [x async for x in client.collections(retry=RETRIES)] assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB]) +@pytest.mark.parametrize("database", TEST_DATABASES) async def test_collections_w_import(database): from google.cloud import firestore @@ -188,7 +188,7 @@ async def test_collections_w_import(database): assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) collection_id = "doc-create" + UNIQUE_RESOURCE_ID @@ -234,7 +234,7 @@ async def test_create_document(client, cleanup, database): assert stored_data == expected_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collections_w_read_time(client, cleanup, database): first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID first_document_id = "doc" + UNIQUE_RESOURCE_ID @@ -245,7 +245,6 @@ async def test_collections_w_read_time(client, cleanup, database): data = {"status": "new"} write_result = await first_document.create(data) read_time = write_result.update_time - num_collections = len([x async for x in client.collections(retry=RETRIES)]) second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" @@ -255,7 +254,6 @@ async def test_collections_w_read_time(client, cleanup, database): # Test that listing current collections does have the second id. curr_collections = [x async for x in client.collections(retry=RETRIES)] - assert len(curr_collections) > num_collections ids = [collection.id for collection in curr_collections] assert second_collection_id in ids assert first_collection_id in ids @@ -269,7 +267,7 @@ async def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -295,7 +293,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_collections_w_read_time(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -331,7 +329,7 @@ async def test_document_collections_w_read_time(client, cleanup, database): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID document = client.document("abcde", document_id) @@ -339,7 +337,7 @@ async def test_no_document(client, database): assert snapshot.to_dict() is None -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -369,7 +367,7 @@ async def test_document_set(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_integer_field(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -386,7 +384,7 @@ async def test_document_integer_field(client, cleanup, database): assert snapshot.to_dict() == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set_merge(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = 
client.document("i-did-it", document_id) @@ -419,7 +417,7 @@ async def test_document_set_merge(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_set_w_int_field(client, cleanup, database): document_id = "set-int-key" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -443,7 +441,7 @@ async def test_document_set_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_update_w_int_field(client, cleanup, database): # Attempt to reproduce #5489. document_id = "update-int-key" + UNIQUE_RESOURCE_ID @@ -471,7 +469,7 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -499,7 +497,7 @@ async def test_vector_search_collection(client, database, distance_measure): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -527,7 +525,7 @@ async def test_vector_search_collection_with_filter(client, database, distance_m @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_with_distance_parameters_euclid( client, database ): @@ -559,7 +557,7 @@ async def test_vector_search_collection_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_with_distance_parameters_cosine( client, database ): @@ -591,7 +589,7 @@ async def test_vector_search_collection_with_distance_parameters_cosine( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -620,7 +618,7 @@ async def test_vector_search_collection_group(client, database, distance_measure @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -651,7 +649,7 @@ async def test_vector_search_collection_group_with_filter( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) 
async def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): @@ -683,7 +681,7 @@ async def test_vector_search_collection_group_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_search_collection_group_with_distance_parameters_cosine( client, database ): @@ -718,7 +716,7 @@ async def test_vector_search_collection_group_with_distance_parameters_cosine( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_no_explain_options( client, database, method ): @@ -753,7 +751,7 @@ async def test_vector_query_stream_or_get_w_no_explain_options( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_explain_options_analyze_true( client, query_docs, database, method ): @@ -832,7 +830,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_vector_query_stream_or_get_w_explain_options_analyze_false( client, query_docs, database, method ): @@ -895,7 +893,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_false( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_update_document(client, cleanup, database): document_id = "for-update" + UNIQUE_RESOURCE_ID document = client.document("made", document_id) @@ -968,7 +966,7 @@ def check_snapshot(snapshot, document, data, write_result): assert snapshot.update_time == write_result.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_get(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) document_id = "for-get" + UNIQUE_RESOURCE_ID @@ -994,7 +992,7 @@ async def test_document_get(client, cleanup, database): check_snapshot(snapshot, document, data, write_result) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_document_delete(client, cleanup, database): document_id = "deleted" + UNIQUE_RESOURCE_ID document = client.document("here-to-be", document_id) @@ -1031,7 +1029,7 @@ async def test_document_delete(client, cleanup, database): assert_timestamp_less(delete_time3, delete_time4) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_add(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1133,7 +1131,7 @@ async def test_collection_add(client, cleanup, database): assert set([i async for i in collection3.list_documents()]) == {document_ref5} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1205,7 +1203,7 @@ async def async_query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value, and shows UserWarning""" collection, stored, allowed_vals = query_docs @@ -1221,7 +1219,7 @@ async def test_query_stream_legacy_where(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1232,7 +1230,7 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1243,7 +1241,7 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1255,7 +1253,7 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1269,7 +1267,7 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1283,7 +1281,7 @@ async def test_query_stream_w_order_by(query_docs, database): assert sorted(b_vals, 
reverse=True) == b_vals -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1305,7 +1303,7 @@ async def test_query_stream_w_field_path(query_docs, database): assert expected_ab_pairs == ab_pairs2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1321,7 +1319,7 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1330,7 +1328,7 @@ async def test_query_stream_wo_results(query_docs, database): assert len(values) == 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1347,7 +1345,7 @@ async def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1367,7 +1365,7 @@ async def test_query_stream_w_multiple_filters(query_docs, database): assert pair in matching_pairs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1387,7 +1385,7 @@ async def test_query_stream_w_offset(query_docs, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1412,7 +1410,7 @@ async def test_query_stream_or_get_w_no_explain_options(query_docs, database, me FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1457,7 +1455,7 @@ async def test_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1492,7 +1490,7 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( _verify_explain_metrics_analyze_false(explain_metrics) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1532,7 +1530,7 @@ async def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1572,7 +1570,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1627,7 +1625,7 @@ async def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1660,7 +1658,7 @@ async def test_collection_group_queries(client, cleanup, database): assert found == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_group_queries_startat_endat(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1703,7 +1701,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1772,7 +1770,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_no_explain_options( query_docs, database, method ): @@ -1797,7 +1795,7 @@ async def test_collection_stream_or_get_w_no_explain_options( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1865,7 +1863,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_collection_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1919,7 +1917,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_false( @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_partition_query_no_partitions(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1953,7 +1951,7 @@ async def test_partition_query_no_partitions(client, cleanup, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator" ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_partition_query(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID n_docs = 128 * 2 + 127 # Minimum partition size is 128 @@ -1980,7 +1978,7 @@ async def test_partition_query(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_get_all(client, cleanup, database): collection_name = "get-all" + UNIQUE_RESOURCE_ID @@ -2053,7 +2051,7 @@ async def test_get_all(client, cleanup, database): assert not snapshots[2].exists -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_live_bulk_writer(client, cleanup, database): from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.bulk_writer import BulkWriter @@ -2077,7 +2075,7 @@ async def test_live_bulk_writer(client, cleanup, database): assert len(await col.get()) == 50 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_batch(client, cleanup, database): collection_name = "batch" + UNIQUE_RESOURCE_ID @@ -2244,7 +2242,7 @@ async def _do_recursive_delete(client, bulk_writer, empty_philosophers=False): ), f"Snapshot at Socrates{path} should have been 
deleted" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2252,7 +2250,7 @@ async def test_async_recursive_delete_parallelized(client, cleanup, database): await _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2260,7 +2258,7 @@ async def test_async_recursive_delete_serialized(client, cleanup, database): await _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2268,7 +2266,7 @@ async def test_async_recursive_delete_parallelized_empty(client, cleanup, databa await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2276,7 +2274,7 @@ async def test_async_recursive_delete_serialized_empty(client, cleanup, database await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2315,7 +2313,7 @@ async def test_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2339,7 +2337,7 @@ async def test_nested_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query(client, cleanup, database): col = client.collection(f"async-chunked-test{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2355,7 +2353,7 @@ async def test_chunked_query(client, cleanup, database): assert lengths[3] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query_smaller_limit(client, cleanup, database): col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") for 
index in range(10): @@ -2368,7 +2366,7 @@ async def test_chunked_query_smaller_limit(client, cleanup, database): assert lengths[0] == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_and_recursive(client, cleanup, database): col_id = f"chunked-async-recursive-test{UNIQUE_RESOURCE_ID}" documents = [ @@ -2427,7 +2425,7 @@ async def _chain(*iterators): yield value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_async_query_get_default_alias(async_query, database): count_query = async_query.count() result = await count_query.get() @@ -2435,7 +2433,7 @@ async def test_count_async_query_get_default_alias(async_query, database): assert r.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_alias(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2443,7 +2441,7 @@ async def test_async_count_query_get_with_alias(async_query, database): assert r.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_limit(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2459,7 +2457,7 @@ async def test_async_count_query_get_with_limit(async_query, database): assert r.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2474,7 +2472,7 @@ async def test_async_count_query_get_multiple_aggregations(async_query, database assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2486,7 +2484,7 @@ async def test_async_count_query_get_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2498,7 +2496,7 @@ async def test_async_count_query_get_empty_aggregation(async_query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_default_alias(async_query, database): count_query = async_query.count() @@ -2507,7 +2505,7 @@ async def test_async_count_query_stream_default_alias(async_query, database): assert aggregation_result.alias == "field_1" -@pytest.mark.parametrize("database", [None, 
FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_alias(async_query, database): count_query = async_query.count(alias="total") async for result in count_query.stream(): @@ -2515,7 +2513,7 @@ async def test_async_count_query_stream_with_alias(async_query, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_limit(async_query, database): # count without limit count_query = async_query.count(alias="total") @@ -2530,7 +2528,7 @@ async def test_async_count_query_stream_with_limit(async_query, database): assert aggregation_result.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2540,7 +2538,7 @@ async def test_async_count_query_stream_multiple_aggregations(async_query, datab assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2553,7 +2551,7 @@ async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2566,7 +2564,7 @@ async def test_async_count_query_stream_empty_aggregation(async_query, database) assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_default_alias(collection, database): sum_query = collection.sum("stats.product") result = await sum_query.get() @@ -2575,7 +2573,7 @@ async def test_async_sum_query_get_default_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2584,7 +2582,7 @@ async def test_async_sum_query_get_with_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_limit(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2600,7 +2598,7 @@ async def test_async_sum_query_get_with_limit(collection, database): assert r.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2617,7 +2615,7 @@ async def test_async_sum_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_default_alias(collection, database): sum_query = collection.sum("stats.product") @@ -2627,7 +2625,7 @@ async def test_async_sum_query_stream_default_alias(collection, database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") async for result in sum_query.stream(): @@ -2635,7 +2633,7 @@ async def test_async_sum_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2650,7 +2648,7 @@ async def test_async_sum_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2662,7 +2660,7 @@ async def test_async_sum_query_stream_multiple_aggregations(collection, database assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_default_alias(collection, database): avg_query = collection.avg("stats.product") result = await avg_query.get() @@ -2672,7 +2670,7 @@ async def test_async_avg_query_get_default_alias(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2681,7 +2679,7 @@ async def test_async_avg_query_get_with_alias(collection, database): assert r.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_limit(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2697,7 +2695,7 @@ async def test_async_avg_query_get_with_limit(collection, database): assert r.value == 5 / 12 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) async def test_async_avg_query_get_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2714,7 +2712,7 @@ async def test_async_avg_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get() @@ -2725,7 +2723,7 @@ async def test_async_avg_query_get_w_no_explain_options(collection, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_true(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get(explain_options=ExplainOptions(analyze=True)) @@ -2760,7 +2758,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_true(collection, da @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_false( collection, database ): @@ -2791,7 +2789,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_default_alias(collection, database): avg_query = collection.avg("stats.product") @@ -2802,7 +2800,7 @@ async def test_async_avg_query_stream_default_alias(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") async for result in avg_query.stream(): @@ -2810,7 +2808,7 @@ async def test_async_avg_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2826,7 +2824,7 @@ async def test_async_avg_query_stream_with_limit(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2838,7 +2836,7 @@ async def test_async_avg_query_stream_multiple_aggregations(collection, database assert 
aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = avg_query.stream() @@ -2849,7 +2847,7 @@ async def test_async_avg_query_stream_w_no_explain_options(collection, database) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_true( collection, database ): @@ -2894,7 +2892,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_true( @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_false( collection, database ): @@ -2926,7 +2924,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) @@ -2989,7 +2987,7 @@ async def create_in_transaction_helper( raise ValueError("Collection can't have more than 2 docs") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3021,7 +3019,7 @@ async def test_count_query_in_transaction(client, cleanup, database): assert r.value == 2 # there are still only 2 docs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_and_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs and_filter = And( @@ -3037,7 +3035,7 @@ async def test_query_with_and_composite_filter(query_docs, database): assert result.get("stats.product") < 10 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_or_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs or_filter = Or( @@ -3061,7 +3059,7 @@ async def test_query_with_or_composite_filter(query_docs, database): assert lt_10 > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_complex_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs field_filter = FieldFilter("b", "==", 0) @@ -3107,7 +3105,7 @@ async def test_query_with_complex_composite_filter(query_docs, database): assert b_not_3 is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) async def test_or_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3171,7 +3169,7 @@ async def _make_transaction_query(client, cleanup): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3204,7 +3202,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_true( client, cleanup, database ): @@ -3246,7 +3244,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_false( client, cleanup, database ): @@ -3283,7 +3281,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3316,7 +3314,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_true( client, cleanup, database ): @@ -3350,7 +3348,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_false( client, cleanup, database ): @@ -3383,7 +3381,7 @@ async def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_with_read_time(client, cleanup, database): """ Test query profiling in transactions. 
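The hunks above repeatedly replace the inline [None, FIRESTORE_OTHER_DB] parametrize list with a shared TEST_DATABASES constant. Its definition sits outside these hunks, so the following is only a sketch of what that module-level constant presumably looks like, not a line taken from this patch:

    # Presumed definition elsewhere in the system-test module (assumption,
    # inferred from the parametrize changes above): None selects the default
    # database, while FIRESTORE_OTHER_DB names a secondary database, so each
    # parametrized test runs against both.
    TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
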
diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 3abc3619b..47eedc983 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -16,6 +16,8 @@ import pytest from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_async_pipeline(*args, client=mock.Mock()): @@ -379,8 +381,34 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): @pytest.mark.parametrize( "method,args,result_cls", [ - ("generic_stage", ("name",), stages.GenericStage), - ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ("add_fields", (Field.of("n"),), stages.AddFields), + ("remove_fields", ("name",), stages.RemoveFields), + ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("select", ("name",), stages.Select), + ("select", (Field.of("n"),), stages.Select), + ("where", (Exists(Field.of("n")),), stages.Where), + ("find_nearest", ("name", [0.1], 0), stages.FindNearest), + ( + "find_nearest", + ("name", [0.1], 0, stages.FindNearestOptions(10)), + stages.FindNearest, + ), + ("sort", (Field.of("n").descending(),), stages.Sort), + ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("sample", (10,), stages.Sample), + ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), + ("union", (_make_async_pipeline(),), stages.Union), + ("unnest", ("field_name",), stages.Unnest), + ("unnest", ("field_name", "alias"), stages.Unnest), + ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), + ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), + ("generic_stage", ("stage_name",), stages.GenericStage), + ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("offset", (1,), stages.Offset), + ("limit", (1,), stages.Limit), + ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), + ("distinct", ("field_name",), stages.Distinct), + ("distinct", (Field.of("n"), "second"), stages.Distinct), ], ) def test_async_pipeline_methods(method, args, result_cls): @@ -391,3 +419,13 @@ def test_async_pipeline_methods(method, args, result_cls): assert len(start_ppl.stages) == 0 assert len(result_ppl.stages) == 1 assert isinstance(result_ppl.stages[0], result_cls) + + +def test_async_pipeline_aggregate_with_groups(): + start_ppl = _make_async_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == [Field.of("title")] diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 6a3fef3ac..b237ad5ac 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -16,6 +16,8 @@ import pytest from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_pipeline(*args, client=mock.Mock()): @@ -356,8 +358,34 @@ def test_pipeline_execute_stream_equivalence_mocked(): @pytest.mark.parametrize( "method,args,result_cls", [ - ("generic_stage", ("name",), stages.GenericStage), - ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + 
("add_fields", (Field.of("n"),), stages.AddFields), + ("remove_fields", ("name",), stages.RemoveFields), + ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("select", ("name",), stages.Select), + ("select", (Field.of("n"),), stages.Select), + ("where", (Exists(Field.of("n")),), stages.Where), + ("find_nearest", ("name", [0.1], 0), stages.FindNearest), + ( + "find_nearest", + ("name", [0.1], 0, stages.FindNearestOptions(10)), + stages.FindNearest, + ), + ("sort", (Field.of("n").descending(),), stages.Sort), + ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("sample", (10,), stages.Sample), + ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), + ("union", (_make_pipeline(),), stages.Union), + ("unnest", ("field_name",), stages.Unnest), + ("unnest", ("field_name", "alias"), stages.Unnest), + ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), + ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), + ("generic_stage", ("stage_name",), stages.GenericStage), + ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("offset", (1,), stages.Offset), + ("limit", (1,), stages.Limit), + ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), + ("distinct", ("field_name",), stages.Distinct), + ("distinct", (Field.of("n"), "second"), stages.Distinct), ], ) def test_pipeline_methods(method, args, result_cls): @@ -368,3 +396,13 @@ def test_pipeline_methods(method, args, result_cls): assert len(start_ppl.stages) == 0 assert len(result_ppl.stages) == 1 assert isinstance(result_ppl.stages[0], result_cls) + + +def test_pipeline_aggregate_with_groups(): + start_ppl = _make_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == [Field.of("title")] diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 19ebed3b5..936c0a0a9 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -10,15 +10,68 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. 
import pytest +import mock import datetime -import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import document as document_pb +from google.cloud.firestore_v1.types import query as query_pb from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs +import google.cloud.firestore_v1.pipeline_expressions as expr + + +@pytest.fixture +def mock_client(): + client = mock.Mock(spec=["_database_string", "collection"]) + client._database_string = "projects/p/databases/d" + return client + + +class TestOrdering: + @pytest.mark.parametrize( + "direction_arg,expected_direction", + [ + ("ASCENDING", expr.Ordering.Direction.ASCENDING), + ("DESCENDING", expr.Ordering.Direction.DESCENDING), + ("ascending", expr.Ordering.Direction.ASCENDING), + ("descending", expr.Ordering.Direction.DESCENDING), + (expr.Ordering.Direction.ASCENDING, expr.Ordering.Direction.ASCENDING), + (expr.Ordering.Direction.DESCENDING, expr.Ordering.Direction.DESCENDING), + ], + ) + def test_ctor(self, direction_arg, expected_direction): + instance = expr.Ordering("field1", direction_arg) + assert isinstance(instance.expr, expr.Field) + assert instance.expr.path == "field1" + assert instance.order_dir == expected_direction + + def test_repr(self): + field_expr = expr.Field.of("field1") + instance = expr.Ordering(field_expr, "ASCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').ascending()" + + instance = expr.Ordering(field_expr, "DESCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').descending()" + + def test_to_pb(self): + field_expr = expr.Field.of("field1") + instance = expr.Ordering(field_expr, "ASCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "ascending" + + instance = expr.Ordering(field_expr, "DESCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "descending" class TestExpr: @@ -27,7 +80,81 @@ def test_ctor(self): Base class should be abstract """ with pytest.raises(TypeError): - expressions.Expr() + expr.Expr() + + @pytest.mark.parametrize( + "method,args,result_cls", + [ + ("add", (2,), expr.Add), + ("subtract", (2,), expr.Subtract), + ("multiply", (2,), expr.Multiply), + ("divide", (2,), expr.Divide), + ("mod", (2,), expr.Mod), + ("logical_max", (2,), expr.LogicalMax), + ("logical_min", (2,), expr.LogicalMin), + ("eq", (2,), expr.Eq), + ("neq", (2,), expr.Neq), + ("lt", (2,), expr.Lt), + ("lte", (2,), expr.Lte), + ("gt", (2,), expr.Gt), + ("gte", (2,), expr.Gte), + ("in_any", ([None],), expr.In), + ("not_in_any", ([None],), expr.Not), + ("array_contains", (None,), expr.ArrayContains), + ("array_contains_all", ([None],), expr.ArrayContainsAll), + ("array_contains_any", ([None],), expr.ArrayContainsAny), + ("array_length", (), expr.ArrayLength), + ("array_reverse", (), expr.ArrayReverse), + ("is_nan", (), expr.IsNaN), + ("exists", (), expr.Exists), + ("sum", (), expr.Sum), + ("avg", (), expr.Avg), + ("count", (), expr.Count), + ("min", (), expr.Min), + ("max", (), expr.Max), + ("char_length", (), expr.CharLength), + 
("byte_length", (), expr.ByteLength), + ("like", ("pattern",), expr.Like), + ("regex_contains", ("regex",), expr.RegexContains), + ("regex_matches", ("regex",), expr.RegexMatch), + ("str_contains", ("substring",), expr.StrContains), + ("starts_with", ("prefix",), expr.StartsWith), + ("ends_with", ("postfix",), expr.EndsWith), + ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), + ("map_get", ("key",), expr.MapGet), + ("vector_length", (), expr.VectorLength), + ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), + ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), + ("timestamp_to_unix_millis", (), expr.TimestampToUnixMillis), + ("unix_millis_to_timestamp", (), expr.UnixMillisToTimestamp), + ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), + ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), + ("timestamp_add", ("day", 1), expr.TimestampAdd), + ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), + ("ascending", (), expr.Ordering), + ("descending", (), expr.Ordering), + ("as_", ("alias",), expr.ExprWithAlias), + ], + ) + @pytest.mark.parametrize( + "base_instance", + [ + expr.Constant(1), + expr.Function.add("1", 1), + expr.Field.of("test"), + expr.Constant(1).as_("one"), + ], + ) + def test_infix_call(self, method, args, result_cls, base_instance): + """ + many FilterCondition expressions support infix execution, and are exposed as methods on Expr. Test calling them + """ + method_ptr = getattr(base_instance, method) + + result = method_ptr(*args) + assert isinstance(result, result_cls) + if isinstance(result, expr.Function) and not method == "not_in_any": + assert result.params[0] == base_instance class TestConstant: @@ -73,7 +200,7 @@ class TestConstant: ], ) def test_to_pb(self, input_val, to_pb_val): - instance = expressions.Constant.of(input_val) + instance = expr.Constant.of(input_val) assert instance._to_pb() == to_pb_val @pytest.mark.parametrize( @@ -99,6 +226,1006 @@ def test_to_pb(self, input_val, to_pb_val): ], ) def test_repr(self, input_val, expected): - instance = expressions.Constant.of(input_val) + instance = expr.Constant.of(input_val) repr_string = repr(instance) assert repr_string == expected + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.Constant.of(1), expr.Constant.of(2), False), + (expr.Constant.of(1), expr.Constant.of(1), True), + (expr.Constant.of(1), 1, True), + (expr.Constant.of(1), 2, False), + (expr.Constant.of("1"), 1, False), + (expr.Constant.of("1"), "1", True), + (expr.Constant.of(None), expr.Constant.of(0), False), + (expr.Constant.of(None), expr.Constant.of(None), True), + (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2, 3]), True), + (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2]), False), + (expr.Constant.of([1, 2, 3]), [1, 2, 3], True), + (expr.Constant.of([1, 2, 3]), object(), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + +class TestListOfExprs: + def test_to_pb(self): + instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + result = instance._to_pb() + assert len(result.array_value.values) == 2 + assert result.array_value.values[0].integer_value == 1 + assert result.array_value.values[1].integer_value == 2 + + def test_empty_to_pb(self): + instance = expr.ListOfExprs([]) + result = instance._to_pb() + assert len(result.array_value.values) == 0 + + def test_repr(self): + instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + repr_string = repr(instance) + assert 
repr_string == "ListOfExprs([Constant.of(1), Constant.of(2)])" + empty_instance = expr.ListOfExprs([]) + empty_repr_string = repr(empty_instance) + assert empty_repr_string == "ListOfExprs([])" + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.ListOfExprs([]), expr.ListOfExprs([]), True), + (expr.ListOfExprs([]), expr.ListOfExprs([expr.Constant(1)]), False), + (expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([]), False), + ( + expr.ListOfExprs([expr.Constant(1)]), + expr.ListOfExprs([expr.Constant(1)]), + True, + ), + ( + expr.ListOfExprs([expr.Constant(1)]), + expr.ListOfExprs([expr.Constant(2)]), + False, + ), + ( + expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + True, + ), + (expr.ListOfExprs([expr.Constant(1)]), [expr.Constant(1)], False), + (expr.ListOfExprs([expr.Constant(1)]), [1], False), + (expr.ListOfExprs([expr.Constant(1)]), object(), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + +class TestSelectable: + """ + contains tests for each Expr class that derives from Selectable + """ + + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expr.Selectable() + + def test_value_from_selectables(self): + selectable_list = [ + expr.Field.of("field1"), + expr.Field.of("field2").as_("alias2"), + ] + result = expr.Selectable._value_from_selectables(*selectable_list) + assert len(result.map_value.fields) == 2 + assert result.map_value.fields["field1"].field_reference_value == "field1" + assert result.map_value.fields["alias2"].field_reference_value == "field2" + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.Field.of("field1"), expr.Field.of("field1"), True), + (expr.Field.of("field1"), expr.Field.of("field2"), False), + (expr.Field.of(None), object(), False), + (expr.Field.of("f").as_("a"), expr.Field.of("f").as_("a"), True), + (expr.Field.of("one").as_("a"), expr.Field.of("two").as_("a"), False), + (expr.Field.of("f").as_("one"), expr.Field.of("f").as_("two"), False), + (expr.Field.of("field"), expr.Field.of("field").as_("alias"), False), + (expr.Field.of("field").as_("alias"), expr.Field.of("field"), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + class TestField: + def test_repr(self): + instance = expr.Field.of("field1") + repr_string = repr(instance) + assert repr_string == "Field.of('field1')" + + def test_of(self): + instance = expr.Field.of("field1") + assert instance.path == "field1" + + def test_to_pb(self): + instance = expr.Field.of("field1") + result = instance._to_pb() + assert result.field_reference_value == "field1" + + def test_to_map(self): + instance = expr.Field.of("field1") + result = instance._to_map() + assert result[0] == "field1" + assert result[1] == Value(field_reference_value="field1") + + class TestExprWithAlias: + def test_repr(self): + instance = expr.Field.of("field1").as_("alias1") + assert repr(instance) == "Field.of('field1').as_('alias1')" + + def test_ctor(self): + arg = expr.Field.of("field1") + alias = "alias1" + instance = expr.ExprWithAlias(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = expr.Field.of("field1") + alias = "alias1" + instance = expr.ExprWithAlias(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + instance = 
expr.Field.of("field1").as_("alias1") + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == Value(field_reference_value="field1") + + +class TestFilterCondition: + def test__from_query_filter_pb_composite_filter_or(self, mock_client): + """ + test composite OR filters + + should create an or statement, made up of ands checking of existance of relevant fields + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + ) + filter2_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(unary_filter=filter2_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + # should include existance checks + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Eq(expr.Field.of("field2"), expr.Constant(None)), + ) + expected = expr.Or(expected_cond1, expected_cond2) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_composite_filter_and(self, mock_client): + """ + test composite AND filters + + should create an and statement, made up of ands checking of existance of relevant fields + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=_helpers.encode_value(100), + ) + filter2_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + value=_helpers.encode_value(200), + ) + + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(field_filter=filter2_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + # should include existance checks + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Gt(expr.Field.of("field1"), expr.Constant(100)), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Lt(expr.Field.of("field2"), expr.Constant(200)), + ) + expected = expr.And(expected_cond1, expected_cond2) + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_composite_filter_nested(self, mock_client): + """ + test composite filter with complex nested checks + """ + # OR (field1 == "val1", AND(field2 > 10, field3 IS NOT NULL)) + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + 
) + filter2_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=_helpers.encode_value(10), + ) + filter3_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field3"), + op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, + ) + inner_and_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter2_pb), + query_pb.StructuredQuery.Filter(unary_filter=filter3_pb), + ], + ) + outer_or_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(composite_filter=inner_and_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=outer_or_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Gt(expr.Field.of("field2"), expr.Constant(10)), + ) + expected_cond3 = expr.And( + expr.Exists(expr.Field.of("field3")), + expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + ) + expected_inner_and = expr.And(expected_cond2, expected_cond3) + expected_outer_or = expr.Or(expected_cond1, expected_inner_and) + + assert repr(result) == repr(expected_outer_or) + + def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): + """ + check composite filter with unsupported operator type + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + ) + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, + filters=[query_pb.StructuredQuery.Filter(field_filter=filter1_pb)], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): + FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + @pytest.mark.parametrize( + "op_enum, expected_expr_func", + [ + (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, + lambda f: expr.Not(f.is_nan()), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, + lambda f: f.eq(None), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, + lambda f: expr.Not(f.eq(None)), + ), + ], + ) + def test__from_query_filter_pb_unary_filter( + self, mock_client, op_enum, expected_expr_func + ): + """ + test supported unary filters + """ + field_path = "unary_field" + filter_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + field_expr_inst = expr.Field.of(field_path) + expected_condition = 
expected_expr_func(field_expr_inst)
+        # should include existence checks
+        expected = expr.And(expr.Exists(field_expr_inst), expected_condition)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client):
+        """
+        check unary filter with unsupported operator type
+        """
+        field_path = "unary_field"
+        filter_pb = query_pb.StructuredQuery.UnaryFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=query_pb.StructuredQuery.UnaryFilter.Operator.OPERATOR_UNSPECIFIED,  # Unknown op
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb)
+
+        with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"):
+            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+    @pytest.mark.parametrize(
+        "op_enum, value, expected_expr_func",
+        [
+            (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt),
+            (
+                query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL,
+                10,
+                expr.Lte,
+            ),
+            (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt),
+            (
+                query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL,
+                10,
+                expr.Gte,
+            ),
+            (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq),
+            (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq),
+            (
+                query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS,
+                10,
+                expr.ArrayContains,
+            ),
+            (
+                query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY,
+                [10, 20],
+                expr.ArrayContainsAny,
+            ),
+            (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In),
+            (
+                query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN,
+                [10, 20],
+                lambda f, v: expr.Not(f.in_any(v)),
+            ),
+        ],
+    )
+    def test__from_query_filter_pb_field_filter(
+        self, mock_client, op_enum, value, expected_expr_func
+    ):
+        """
+        test supported field filters
+        """
+        field_path = "test_field"
+        value_pb = _helpers.encode_value(value)
+        filter_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=op_enum,
+            value=value_pb,
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
+
+        result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+        field_expr = expr.Field.of(field_path)
+        # convert values into constants
+        value = (
+            [expr.Constant(e) for e in value]
+            if isinstance(value, list)
+            else expr.Constant(value)
+        )
+        expected_condition = expected_expr_func(field_expr, value)
+        # should include existence checks
+        expected = expr.And(expr.Exists(field_expr), expected_condition)
+
+        assert repr(result) == repr(expected)
+
+    def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client):
+        """
+        check field filter with unsupported operator type
+        """
+        field_path = "test_field"
+        value_pb = _helpers.encode_value(10)
+        filter_pb = query_pb.StructuredQuery.FieldFilter(
+            field=query_pb.StructuredQuery.FieldReference(field_path=field_path),
+            op=query_pb.StructuredQuery.FieldFilter.Operator.OPERATOR_UNSPECIFIED,  # Unknown op
+            value=value_pb,
+        )
+        wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb)
+
+        with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"):
+            FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client)
+
+    def test__from_query_filter_pb_unknown_filter_type(self, mock_client):
+        """
+        test with unsupported filter type
+        """
+        # Test with an unexpected protobuf type
+        
with pytest.raises(TypeError, match="Unexpected filter type"): + FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + + +class TestFilterConditionClasses: + """ + contains test methods for each Expr class that derives from FilterCondition + """ + + def _make_arg(self, name="Mock"): + arg = mock.Mock() + arg.__repr__ = lambda x: name + return arg + + def test_and(self): + arg1 = self._make_arg() + arg2 = self._make_arg() + instance = expr.And(arg1, arg2) + assert instance.name == "and" + assert instance.params == [arg1, arg2] + assert repr(instance) == "And(Mock, Mock)" + + def test_or(self): + arg1 = self._make_arg("Arg1") + arg2 = self._make_arg("Arg2") + instance = expr.Or(arg1, arg2) + assert instance.name == "or" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Arg1.or(Arg2)" + + def test_array_contains(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element") + instance = expr.ArrayContains(arg1, arg2) + assert instance.name == "array_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayField.array_contains(Element)" + + def test_array_contains_any(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = expr.ArrayContainsAny(arg1, [arg2, arg3]) + assert instance.name == "array_contains_any" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))" + ) + + def test_exists(self): + arg1 = self._make_arg("Field") + instance = expr.Exists(arg1) + assert instance.name == "exists" + assert instance.params == [arg1] + assert repr(instance) == "Field.exists()" + + def test_eq(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Eq(arg1, arg2) + assert instance.name == "eq" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.eq(Right)" + + def test_gte(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Gte(arg1, arg2) + assert instance.name == "gte" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.gte(Right)" + + def test_gt(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Gt(arg1, arg2) + assert instance.name == "gt" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.gt(Right)" + + def test_lte(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Lte(arg1, arg2) + assert instance.name == "lte" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.lte(Right)" + + def test_lt(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Lt(arg1, arg2) + assert instance.name == "lt" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.lt(Right)" + + def test_neq(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Neq(arg1, arg2) + assert instance.name == "neq" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.neq(Right)" + + def test_in(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = expr.In(arg1, [arg2, arg3]) + assert instance.name == "in" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + 
assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + + def test_is_nan(self): + arg1 = self._make_arg("Value") + instance = expr.IsNaN(arg1) + assert instance.name == "is_nan" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_nan()" + + def test_not(self): + arg1 = self._make_arg("Condition") + instance = expr.Not(arg1) + assert instance.name == "not" + assert instance.params == [arg1] + assert repr(instance) == "Not(Condition)" + + def test_array_contains_all(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) + assert instance.name == "array_contains_all" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(ListOfExprs([Element1, Element2]))" + ) + + def test_ends_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Postfix") + instance = expr.EndsWith(arg1, arg2) + assert instance.name == "ends_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.ends_with(Postfix)" + + def test_if(self): + arg1 = self._make_arg("Condition") + arg2 = self._make_arg("TrueExpr") + arg3 = self._make_arg("FalseExpr") + instance = expr.If(arg1, arg2, arg3) + assert instance.name == "if" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + + def test_like(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Pattern") + instance = expr.Like(arg1, arg2) + assert instance.name == "like" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.like(Pattern)" + + def test_regex_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexContains(arg1, arg2) + assert instance.name == "regex_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_contains(Regex)" + + def test_regex_match(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexMatch(arg1, arg2) + assert instance.name == "regex_match" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_match(Regex)" + + def test_starts_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Prefix") + instance = expr.StartsWith(arg1, arg2) + assert instance.name == "starts_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.starts_with(Prefix)" + + def test_str_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Substring") + instance = expr.StrContains(arg1, arg2) + assert instance.name == "str_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.str_contains(Substring)" + + def test_xor(self): + arg1 = self._make_arg("Condition1") + arg2 = self._make_arg("Condition2") + instance = expr.Xor([arg1, arg2]) + assert instance.name == "xor" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Xor(Condition1, Condition2)" + + +class TestFunctionClasses: + """ + contains test methods for each Expr class that derives from Function + """ + + @pytest.mark.parametrize( + "method,args,result_cls", + [ + ("add", ("field", 2), expr.Add), + ("subtract", ("field", 2), expr.Subtract), + ("multiply", ("field", 2), expr.Multiply), + ("divide", ("field", 
2), expr.Divide),
+            ("mod", ("field", 2), expr.Mod),
+            ("logical_max", ("field", 2), expr.LogicalMax),
+            ("logical_min", ("field", 2), expr.LogicalMin),
+            ("eq", ("field", 2), expr.Eq),
+            ("neq", ("field", 2), expr.Neq),
+            ("lt", ("field", 2), expr.Lt),
+            ("lte", ("field", 2), expr.Lte),
+            ("gt", ("field", 2), expr.Gt),
+            ("gte", ("field", 2), expr.Gte),
+            ("in_any", ("field", [None]), expr.In),
+            ("not_in_any", ("field", [None]), expr.Not),
+            ("array_contains", ("field", None), expr.ArrayContains),
+            ("array_contains_all", ("field", [None]), expr.ArrayContainsAll),
+            ("array_contains_any", ("field", [None]), expr.ArrayContainsAny),
+            ("array_length", ("field",), expr.ArrayLength),
+            ("array_reverse", ("field",), expr.ArrayReverse),
+            ("is_nan", ("field",), expr.IsNaN),
+            ("exists", ("field",), expr.Exists),
+            ("sum", ("field",), expr.Sum),
+            ("avg", ("field",), expr.Avg),
+            ("count", ("field",), expr.Count),
+            ("count", (), expr.Count),
+            ("min", ("field",), expr.Min),
+            ("max", ("field",), expr.Max),
+            ("char_length", ("field",), expr.CharLength),
+            ("byte_length", ("field",), expr.ByteLength),
+            ("like", ("field", "pattern"), expr.Like),
+            ("regex_contains", ("field", "regex"), expr.RegexContains),
+            ("regex_matches", ("field", "regex"), expr.RegexMatch),
+            ("str_contains", ("field", "substring"), expr.StrContains),
+            ("starts_with", ("field", "prefix"), expr.StartsWith),
+            ("ends_with", ("field", "postfix"), expr.EndsWith),
+            ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat),
+            ("map_get", ("field", "key"), expr.MapGet),
+            ("vector_length", ("field",), expr.VectorLength),
+            ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros),
+            ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp),
+            ("timestamp_to_unix_millis", ("field",), expr.TimestampToUnixMillis),
+            ("unix_millis_to_timestamp", ("field",), expr.UnixMillisToTimestamp),
+            ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds),
+            ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp),
+            ("timestamp_add", ("field", "day", 1), expr.TimestampAdd),
+            ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub),
+        ],
+    )
+    def test_function_builder(self, method, args, result_cls):
+        """
+        Test building functions using the methods exposed on the base Function class.
+ """ + method_ptr = getattr(expr.Function, method) + + result = method_ptr(*args) + assert isinstance(result, result_cls) + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.ArrayElement(), expr.ArrayElement(), True), + (expr.ArrayElement(), expr.CharLength(1), False), + (expr.ArrayElement(), object(), False), + (expr.ArrayElement(), None, False), + (expr.CharLength(1), expr.ArrayElement(), False), + (expr.CharLength(1), expr.CharLength(2), False), + (expr.CharLength(1), expr.CharLength(1), True), + (expr.CharLength(1), expr.ByteLength(1), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + def _make_arg(self, name="Mock"): + arg = mock.Mock() + arg.__repr__ = lambda x: name + return arg + + def test_divide(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Divide(arg1, arg2) + assert instance.name == "divide" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Divide(Left, Right)" + + def test_logical_max(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.LogicalMax(arg1, arg2) + assert instance.name == "logical_maximum" + assert instance.params == [arg1, arg2] + assert repr(instance) == "LogicalMax(Left, Right)" + + def test_logical_min(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.LogicalMin(arg1, arg2) + assert instance.name == "logical_minimum" + assert instance.params == [arg1, arg2] + assert repr(instance) == "LogicalMin(Left, Right)" + + def test_map_get(self): + arg1 = self._make_arg("Map") + arg2 = expr.Constant("Key") + instance = expr.MapGet(arg1, arg2) + assert instance.name == "map_get" + assert instance.params == [arg1, arg2] + assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + + def test_mod(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Mod(arg1, arg2) + assert instance.name == "mod" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Mod(Left, Right)" + + def test_multiply(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Multiply(arg1, arg2) + assert instance.name == "multiply" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Multiply(Left, Right)" + + def test_parent(self): + arg1 = self._make_arg("Value") + instance = expr.Parent(arg1) + assert instance.name == "parent" + assert instance.params == [arg1] + assert repr(instance) == "Parent(Value)" + + def test_str_concat(self): + arg1 = self._make_arg("Str1") + arg2 = self._make_arg("Str2") + instance = expr.StrConcat(arg1, arg2) + assert instance.name == "str_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "StrConcat(Str1, Str2)" + + def test_subtract(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Subtract(arg1, arg2) + assert instance.name == "subtract" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Subtract(Left, Right)" + + def test_timestamp_add(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = self._make_arg("Amount") + instance = expr.TimestampAdd(arg1, arg2, arg3) + assert instance.name == "timestamp_add" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" + + def test_timestamp_sub(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = 
self._make_arg("Amount") + instance = expr.TimestampSub(arg1, arg2, arg3) + assert instance.name == "timestamp_sub" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + + def test_timestamp_to_unix_micros(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixMicros(arg1) + assert instance.name == "timestamp_to_unix_micros" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixMicros(Input)" + + def test_timestamp_to_unix_millis(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixMillis(arg1) + assert instance.name == "timestamp_to_unix_millis" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixMillis(Input)" + + def test_timestamp_to_unix_seconds(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixSeconds(arg1) + assert instance.name == "timestamp_to_unix_seconds" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixSeconds(Input)" + + def test_unix_micros_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixMicrosToTimestamp(arg1) + assert instance.name == "unix_micros_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixMicrosToTimestamp(Input)" + + def test_unix_millis_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixMillisToTimestamp(arg1) + assert instance.name == "unix_millis_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixMillisToTimestamp(Input)" + + def test_unix_seconds_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixSecondsToTimestamp(arg1) + assert instance.name == "unix_seconds_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixSecondsToTimestamp(Input)" + + def test_vector_length(self): + arg1 = self._make_arg("Array") + instance = expr.VectorLength(arg1) + assert instance.name == "vector_length" + assert instance.params == [arg1] + assert repr(instance) == "VectorLength(Array)" + + def test_add(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Add(arg1, arg2) + assert instance.name == "add" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Add(Left, Right)" + + def test_array_element(self): + instance = expr.ArrayElement() + assert instance.name == "array_element" + assert instance.params == [] + assert repr(instance) == "ArrayElement()" + + def test_array_filter(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("FilterCond") + instance = expr.ArrayFilter(arg1, arg2) + assert instance.name == "array_filter" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayFilter(Array, FilterCond)" + + def test_array_length(self): + arg1 = self._make_arg("Array") + instance = expr.ArrayLength(arg1) + assert instance.name == "array_length" + assert instance.params == [arg1] + assert repr(instance) == "ArrayLength(Array)" + + def test_array_reverse(self): + arg1 = self._make_arg("Array") + instance = expr.ArrayReverse(arg1) + assert instance.name == "array_reverse" + assert instance.params == [arg1] + assert repr(instance) == "ArrayReverse(Array)" + + def test_array_transform(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("TransformFunc") + instance = expr.ArrayTransform(arg1, arg2) + assert instance.name == "array_transform" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayTransform(Array, 
TransformFunc)" + + def test_byte_length(self): + arg1 = self._make_arg("Expr") + instance = expr.ByteLength(arg1) + assert instance.name == "byte_length" + assert instance.params == [arg1] + assert repr(instance) == "ByteLength(Expr)" + + def test_char_length(self): + arg1 = self._make_arg("Expr") + instance = expr.CharLength(arg1) + assert instance.name == "char_length" + assert instance.params == [arg1] + assert repr(instance) == "CharLength(Expr)" + + def test_collection_id(self): + arg1 = self._make_arg("Value") + instance = expr.CollectionId(arg1) + assert instance.name == "collection_id" + assert instance.params == [arg1] + assert repr(instance) == "CollectionId(Value)" + + def test_sum(self): + arg1 = self._make_arg("Value") + instance = expr.Sum(arg1) + assert instance.name == "sum" + assert instance.params == [arg1] + assert repr(instance) == "Sum(Value)" + + def test_avg(self): + arg1 = self._make_arg("Value") + instance = expr.Avg(arg1) + assert instance.name == "avg" + assert instance.params == [arg1] + assert repr(instance) == "Avg(Value)" + + def test_count(self): + arg1 = self._make_arg("Value") + instance = expr.Count(arg1) + assert instance.name == "count" + assert instance.params == [arg1] + assert repr(instance) == "Count(Value)" + + def test_count_empty(self): + instance = expr.Count() + assert instance.params == [] + assert repr(instance) == "Count()" + + def test_min(self): + arg1 = self._make_arg("Value") + instance = expr.Min(arg1) + assert instance.name == "minimum" + assert instance.params == [arg1] + assert repr(instance) == "Min(Value)" + + def test_max(self): + arg1 = self._make_arg("Value") + instance = expr.Max(arg1) + assert instance.name == "maximum" + assert instance.params == [arg1] + assert repr(instance) == "Max(Value)" diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index cd8b56b68..bed1bd05a 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -18,6 +18,7 @@ from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_document import BaseDocumentReference class TestPipelineSource: @@ -44,6 +45,49 @@ def test_collection(self): assert isinstance(first_stage, stages.Collection) assert first_stage.path == "/path" + def test_collection_w_tuple(self): + instance = self._make_client().pipeline() + ppl = instance.collection(("a", "b", "c")) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/a/b/c" + + def test_collection_group(self): + instance = self._make_client().pipeline() + ppl = instance.collection_group("id") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.CollectionGroup) + assert first_stage.collection_id == "id" + + def test_database(self): + instance = self._make_client().pipeline() + ppl = instance.database() + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Database) + + def test_documents(self): + instance = self._make_client().pipeline() + test_documents = [ + BaseDocumentReference("a", "1"), + BaseDocumentReference("a", "2"), + 
BaseDocumentReference("a", "3"), + ] + ppl = instance.documents(*test_documents) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Documents) + assert len(first_stage.paths) == 3 + assert first_stage.paths[0] == "/a/1" + assert first_stage.paths[1] == "/a/2" + assert first_stage.paths[2] == "/a/3" + class TestPipelineSourceWithAsyncClient(TestPipelineSource): """ diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 59d808d63..e67a4ca3a 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -13,11 +13,21 @@ # limitations under the License import pytest +from unittest import mock +from google.cloud.firestore_v1.base_pipeline import _BasePipeline import google.cloud.firestore_v1._pipeline_stages as stages -from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.pipeline_expressions import ( + Constant, + Field, + Ordering, + Sum, + Count, +) from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure class TestStage: @@ -29,6 +39,113 @@ def test_ctor(self): stages.Stage() +class TestAddFields: + def _make_one(self, *args, **kwargs): + return stages.AddFields(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + field2_aliased = Field.of("field2").as_("alias2") + instance = self._make_one(field1, field2_aliased) + assert instance.fields == [field1, field2_aliased] + assert instance.name == "add_fields" + + def test_repr(self): + field1 = Field.of("field1").as_("f1") + instance = self._make_one(field1) + repr_str = repr(instance) + assert repr_str == "AddFields(fields=[Field.of('field1').as_('f1')])" + + def test_to_pb(self): + field1 = Field.of("field1") + field2_aliased = Field.of("field2").as_("alias2") + instance = self._make_one(field1, field2_aliased) + result = instance._to_pb() + assert result.name == "add_fields" + assert len(result.args) == 1 + expected_map_value = { + "fields": { + "field1": Value(field_reference_value="field1"), + "alias2": Value(field_reference_value="field2"), + } + } + assert result.args[0].map_value.fields == expected_map_value["fields"] + assert len(result.options) == 0 + + +class TestAggregate: + def _make_one(self, *args, **kwargs): + return stages.Aggregate(*args, **kwargs) + + def test_ctor_positional(self): + """test with only positional arguments""" + sum_total = Sum(Field.of("total")).as_("sum_total") + avg_price = Field.of("price").avg().as_("avg_price") + instance = self._make_one(sum_total, avg_price) + assert list(instance.accumulators) == [sum_total, avg_price] + assert len(instance.groups) == 0 + assert instance.name == "aggregate" + + def test_ctor_keyword(self): + """test with only keyword arguments""" + sum_total = Sum(Field.of("total")).as_("sum_total") + avg_price = Field.of("price").avg().as_("avg_price") + group_category = Field.of("category") + instance = self._make_one( + accumulators=[avg_price, sum_total], groups=[group_category, "city"] + ) + assert instance.accumulators == [avg_price, sum_total] + assert len(instance.groups) == 2 + assert instance.groups[0] == group_category + assert isinstance(instance.groups[1], Field) + assert instance.groups[1].path == "city" + assert instance.name == "aggregate" + + 
def test_ctor_combined(self): + """test with a mix of arguments""" + sum_total = Sum(Field.of("total")).as_("sum_total") + avg_price = Field.of("price").avg().as_("avg_price") + count = Count(Field.of("total")).as_("count") + with pytest.raises(ValueError): + self._make_one(sum_total, accumulators=[avg_price, count]) + + def test_repr(self): + sum_total = Sum(Field.of("total")).as_("sum_total") + group_category = Field.of("category") + instance = self._make_one(sum_total, groups=[group_category]) + repr_str = repr(instance) + assert ( + repr_str + == "Aggregate(Sum(Field.of('total')).as_('sum_total'), groups=[Field.of('category')])" + ) + + def test_to_pb(self): + sum_total = Sum(Field.of("total")).as_("sum_total") + group_category = Field.of("category") + instance = self._make_one(sum_total, groups=[group_category]) + result = instance._to_pb() + assert result.name == "aggregate" + assert len(result.args) == 2 + + expected_accumulators_map = { + "fields": { + "sum_total": Value( + function_value={ + "name": "sum", + "args": [Value(field_reference_value="total")], + } + ) + } + } + assert result.args[0].map_value.fields == expected_accumulators_map["fields"] + + expected_groups_map = { + "fields": {"category": Value(field_reference_value="category")} + } + assert result.args[1].map_value.fields == expected_groups_map["fields"] + assert len(result.options) == 0 + + class TestCollection: def _make_one(self, *args, **kwargs): return stages.Collection(*args, **kwargs) @@ -55,6 +172,256 @@ def test_to_pb(self): assert len(result.options) == 0 +class TestCollectionGroup: + def _make_one(self, *args, **kwargs): + return stages.CollectionGroup(*args, **kwargs) + + def test_repr(self): + input_arg = "test" + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == "CollectionGroup(collection_id='test')" + + def test_to_pb(self): + input_arg = "test" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection_group" + assert len(result.args) == 1 + assert result.args[0].string_value == "test" + assert len(result.options) == 0 + + +class TestDatabase: + def _make_one(self, *args, **kwargs): + return stages.Database(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one() + assert instance.name == "database" + + def test_repr(self): + instance = self._make_one() + repr_str = repr(instance) + assert repr_str == "Database()" + + def test_to_pb(self): + instance = self._make_one() + result = instance._to_pb() + assert result.name == "database" + assert len(result.args) == 0 + assert len(result.options) == 0 + + +class TestDistinct: + def _make_one(self, *args, **kwargs): + return stages.Distinct(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + instance = self._make_one("field2", field1) + assert len(instance.fields) == 2 + assert isinstance(instance.fields[0], Field) + assert instance.fields[0].path == "field2" + assert instance.fields[1] == field1 + assert instance.name == "distinct" + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert repr_str == "Distinct(fields=[Field.of('field1'), Field.of('field2')])" + + def test_to_pb(self): + instance = self._make_one("field1", Field.of("field2")) + result = instance._to_pb() + assert result.name == "distinct" + assert len(result.args) == 1 + expected_map_value = { + "fields": { + "field1": Value(field_reference_value="field1"), + "field2": Value(field_reference_value="field2"), + } + 
} + assert result.args[0].map_value.fields == expected_map_value["fields"] + assert len(result.options) == 0 + + +class TestDocuments: + def _make_one(self, *args, **kwargs): + return stages.Documents(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + assert instance.paths == ("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + assert instance.name == "documents" + + def test_of(self): + mock_doc_ref1 = mock.Mock() + mock_doc_ref1.path = "projects/p/databases/d/documents/c/doc1" + mock_doc_ref2 = mock.Mock() + mock_doc_ref2.path = "c/doc2" # Test relative path as well + instance = stages.Documents.of(mock_doc_ref1, mock_doc_ref2) + assert instance.paths == ( + "/projects/p/databases/d/documents/c/doc1", + "/c/doc2", + ) + + def test_repr(self): + instance = self._make_one("/a/b", "/c/d") + repr_str = repr(instance) + assert repr_str == "Documents('/a/b', '/c/d')" + + def test_to_pb(self): + instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2") + result = instance._to_pb() + assert result.name == "documents" + assert len(result.args) == 1 + assert ( + result.args[0].array_value.values[0].string_value + == "/projects/p/databases/d/documents/c/doc1" + ) + assert result.args[0].array_value.values[1].string_value == "/c/doc2" + assert len(result.options) == 0 + + +class TestFindNearest: + class TestFindNearestOptions: + def _make_one_options(self, *args, **kwargs): + return stages.FindNearestOptions(*args, **kwargs) + + def test_ctor_options(self): + limit_val = 10 + distance_field_val = Field.of("dist") + instance = self._make_one_options( + limit=limit_val, distance_field=distance_field_val + ) + assert instance.limit == limit_val + assert instance.distance_field == distance_field_val + + def test_ctor_defaults(self): + instance_default = self._make_one_options() + assert instance_default.limit is None + assert instance_default.distance_field is None + + def test_repr(self): + instance_empty = self._make_one_options() + assert repr(instance_empty) == "FindNearestOptions()" + instance_limit = self._make_one_options(limit=5) + assert repr(instance_limit) == "FindNearestOptions(limit=5)" + instance_distance = self._make_one_options(distance_field=Field.of("dist")) + assert ( + repr(instance_distance) + == "FindNearestOptions(distance_field=Field.of('dist'))" + ) + instance_full = self._make_one_options( + limit=5, distance_field=Field.of("dist") + ) + assert ( + repr(instance_full) + == "FindNearestOptions(limit=5, distance_field=Field.of('dist'))" + ) + + def _make_one(self, *args, **kwargs): + return stages.FindNearest(*args, **kwargs) + + def test_ctor_w_str_field(self): + field_path = "embedding_field" + vector_val = Vector([1.0, 2.0, 3.0]) + distance_measure_val = DistanceMeasure.EUCLIDEAN + options_val = stages.FindNearestOptions( + limit=5, distance_field=Field.of("distance") + ) + + instance_str_field = self._make_one( + field_path, vector_val, distance_measure_val, options=options_val + ) + assert isinstance(instance_str_field.field, Field) + assert instance_str_field.field.path == field_path + assert instance_str_field.vector == vector_val + assert instance_str_field.distance_measure == distance_measure_val + assert instance_str_field.options == options_val + assert instance_str_field.name == "find_nearest" + + def test_ctor_w_field_obj(self): + field_path = "embedding_field" + field_obj = Field.of(field_path) + vector_val = Vector([1.0, 2.0, 3.0]) + distance_measure_val = 
DistanceMeasure.EUCLIDEAN
+        instance_field_obj = self._make_one(field_obj, vector_val, distance_measure_val)
+        assert instance_field_obj.field == field_obj
+        assert instance_field_obj.options.limit is None  # Default options
+        assert instance_field_obj.options.distance_field is None
+
+    def test_ctor_w_vector_list(self):
+        field_path = "embedding_field"
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+
+        vector_list = [4.0, 5.0]
+        instance_list_vector = self._make_one(
+            field_path, vector_list, distance_measure_val
+        )
+        assert isinstance(instance_list_vector.vector, Vector)
+        assert instance_list_vector.vector == Vector(vector_list)
+
+    def test_repr(self):
+        field_path = "embedding_field"
+        vector_val = Vector([1.0, 2.0])
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+        options_val = stages.FindNearestOptions(limit=5)
+        instance = self._make_one(
+            field_path, vector_val, distance_measure_val, options=options_val
+        )
+        repr_str = repr(instance)
+        expected_repr = "FindNearest(field=Field.of('embedding_field'), vector=Vector<1.0, 2.0>, distance_measure=<DistanceMeasure.EUCLIDEAN: 1>, options=FindNearestOptions(limit=5))"
+        assert repr_str == expected_repr
+
+    @pytest.mark.parametrize(
+        "distance_measure_val, expected_str",
+        [
+            (DistanceMeasure.COSINE, "cosine"),
+            (DistanceMeasure.DOT_PRODUCT, "dot_product"),
+            (DistanceMeasure.EUCLIDEAN, "euclidean"),
+        ],
+    )
+    def test_to_pb(self, distance_measure_val, expected_str):
+        field_path = "embedding"
+        vector_val = Vector([0.1, 0.2])
+        options_val = stages.FindNearestOptions(
+            limit=7, distance_field=Field.of("dist_val")
+        )
+        instance = self._make_one(
+            field_path, vector_val, distance_measure_val, options=options_val
+        )
+
+        result = instance._to_pb()
+        assert result.name == "find_nearest"
+        assert len(result.args) == 3
+        # test field arg
+        assert result.args[0].field_reference_value == field_path
+        # test for vector arg
+        assert result.args[1].map_value.fields["__type__"].string_value == "__vector__"
+        assert (
+            result.args[1].map_value.fields["value"].array_value.values[0].double_value
+            == 0.1
+        )
+        assert (
+            result.args[1].map_value.fields["value"].array_value.values[1].double_value
+            == 0.2
+        )
+        # test for distance measure arg
+        assert result.args[2].string_value == expected_str
+        # test options
+        assert len(result.options) == 2
+        assert result.options["limit"].integer_value == 7
+        assert result.options["distance_field"].field_reference_value == "dist_val"
+
+    def test_to_pb_no_options(self):
+        instance = self._make_one("emb", [1.0], DistanceMeasure.DOT_PRODUCT)
+        result = instance._to_pb()
+        assert len(result.options) == 0
+        assert len(result.args) == 3
+
+
 class TestGenericStage:
     def _make_one(self, *args, **kwargs):
         return stages.GenericStage(*args, **kwargs)
@@ -119,3 +486,324 @@ def test_to_pb(self):
         assert result.args[0].boolean_value is True
         assert result.args[1].string_value == "test"
         assert len(result.options) == 0
+
+
+class TestLimit:
+    def _make_one(self, *args, **kwargs):
+        return stages.Limit(*args, **kwargs)
+
+    def test_repr(self):
+        instance = self._make_one(10)
+        repr_str = repr(instance)
+        assert repr_str == "Limit(limit=10)"
+
+    def test_to_pb(self):
+        instance = self._make_one(5)
+        result = instance._to_pb()
+        assert result.name == "limit"
+        assert len(result.args) == 1
+        assert result.args[0].integer_value == 5
+        assert len(result.options) == 0
+
+
+class TestOffset:
+    def _make_one(self, *args, **kwargs):
+        return stages.Offset(*args, **kwargs)
+
+    def test_repr(self):
+        instance = self._make_one(20)
+        repr_str = repr(instance)
+ assert repr_str == "Offset(offset=20)" + + def test_to_pb(self): + instance = self._make_one(3) + result = instance._to_pb() + assert result.name == "offset" + assert len(result.args) == 1 + assert result.args[0].integer_value == 3 + assert len(result.options) == 0 + + +class TestRemoveFields: + def _make_one(self, *args, **kwargs): + return stages.RemoveFields(*args, **kwargs) + + def test_ctor(self): + field1 = Field.of("field1") + instance = self._make_one("field2", field1) + assert len(instance.fields) == 2 + assert isinstance(instance.fields[0], Field) + assert instance.fields[0].path == "field2" + assert instance.fields[1] == field1 + assert instance.name == "remove_fields" + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert repr_str == "RemoveFields(Field.of('field1'), Field.of('field2'))" + + def test_to_pb(self): + instance = self._make_one("field1", Field.of("field2")) + result = instance._to_pb() + assert result.name == "remove_fields" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "field1" + assert result.args[1].field_reference_value == "field2" + assert len(result.options) == 0 + + +class TestSample: + class TestSampleOptions: + def test_ctor_percent(self): + instance = stages.SampleOptions(0.25, stages.SampleOptions.Mode.PERCENT) + assert instance.value == 0.25 + assert instance.mode == stages.SampleOptions.Mode.PERCENT + + def test_ctor_documents(self): + instance = stages.SampleOptions(10, stages.SampleOptions.Mode.DOCUMENTS) + assert instance.value == 10 + assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS + + def test_percentage(self): + instance = stages.SampleOptions.percentage(1) + assert instance.value == 1 + assert instance.mode == stages.SampleOptions.Mode.PERCENT + + def test_doc_limit(self): + instance = stages.SampleOptions.doc_limit(2) + assert instance.value == 2 + assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS + + def test_repr_percentage(self): + instance = stages.SampleOptions.percentage(0.5) + assert repr(instance) == "SampleOptions.percentage(0.5)" + + def test_repr_documents(self): + instance = stages.SampleOptions.doc_limit(10) + assert repr(instance) == "SampleOptions.doc_limit(10)" + + def _make_one(self, *args, **kwargs): + return stages.Sample(*args, **kwargs) + + def test_ctor_w_int(self): + instance_int = self._make_one(10) + assert isinstance(instance_int.options, stages.SampleOptions) + assert instance_int.options.value == 10 + assert instance_int.options.mode == stages.SampleOptions.Mode.DOCUMENTS + assert instance_int.name == "sample" + + def test_ctor_w_options(self): + options = stages.SampleOptions.percentage(0.5) + instance_options = self._make_one(options) + assert instance_options.options == options + assert instance_options.name == "sample" + + def test_repr(self): + instance_int = self._make_one(10) + repr_str_int = repr(instance_int) + assert repr_str_int == "Sample(options=SampleOptions.doc_limit(10))" + + options = stages.SampleOptions.percentage(0.5) + instance_options = self._make_one(options) + repr_str_options = repr(instance_options) + assert repr_str_options == "Sample(options=SampleOptions.percentage(0.5))" + + def test_to_pb_documents_mode(self): + instance_docs = self._make_one(10) + result_docs = instance_docs._to_pb() + assert result_docs.name == "sample" + assert len(result_docs.args) == 2 + assert result_docs.args[0].integer_value == 10 + assert result_docs.args[1].string_value == "documents" + 
assert len(result_docs.options) == 0 + + def test_to_pb_percent_mode(self): + options_percent = stages.SampleOptions.percentage(0.25) + instance_percent = self._make_one(options_percent) + result_percent = instance_percent._to_pb() + assert result_percent.name == "sample" + assert len(result_percent.args) == 2 + assert result_percent.args[0].double_value == 0.25 + assert result_percent.args[1].string_value == "percent" + assert len(result_percent.options) == 0 + + +class TestSelect: + def _make_one(self, *args, **kwargs): + return stages.Select(*args, **kwargs) + + def test_repr(self): + instance = self._make_one("field1", Field.of("field2")) + repr_str = repr(instance) + assert ( + repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])" + ) + + def test_to_pb(self): + instance = self._make_one("field1", "field2.subfield", Field.of("field3")) + result = instance._to_pb() + assert result.name == "select" + assert len(result.args) == 1 + got_map = result.args[0].map_value.fields + assert got_map.get("field1").field_reference_value == "field1" + assert got_map.get("field2.subfield").field_reference_value == "field2.subfield" + assert got_map.get("field3").field_reference_value == "field3" + assert len(result.options) == 0 + + +class TestSort: + def _make_one(self, *args, **kwargs): + return stages.Sort(*args, **kwargs) + + def test_repr(self): + order1 = Ordering(Field.of("field1"), "ASCENDING") + instance = self._make_one(order1) + repr_str = repr(instance) + assert repr_str == "Sort(orders=[Field.of('field1').ascending()])" + + def test_to_pb(self): + order1 = Ordering(Field.of("name"), "ASCENDING") + order2 = Ordering(Field.of("age"), "DESCENDING") + instance = self._make_one(order1, order2) + result = instance._to_pb() + assert result.name == "sort" + assert len(result.args) == 2 + got_map = result.args[0].map_value.fields + assert got_map.get("expression").field_reference_value == "name" + assert got_map.get("direction").string_value == "ascending" + assert len(result.options) == 0 + + +class TestUnion: + def _make_one(self, *args, **kwargs): + return stages.Union(*args, **kwargs) + + def test_ctor(self): + mock_pipeline = mock.Mock(spec=_BasePipeline) + instance = self._make_one(mock_pipeline) + assert instance.other == mock_pipeline + assert instance.name == "union" + + def test_repr(self): + test_pipeline = _BasePipeline(mock.Mock()).sample(5) + instance = self._make_one(test_pipeline) + repr_str = repr(instance) + assert repr_str == f"Union(other={test_pipeline!r})" + + def test_to_pb(self): + test_pipeline = _BasePipeline(mock.Mock()).sample(5) + + instance = self._make_one(test_pipeline) + result = instance._to_pb() + + assert result.name == "union" + assert len(result.args) == 1 + assert result.args[0].pipeline_value == test_pipeline._to_pb().pipeline + assert len(result.options) == 0 + + +class TestUnnest: + class TestUnnestOptions: + def _make_one_options(self, *args, **kwargs): + return stages.UnnestOptions(*args, **kwargs) + + def test_ctor_options(self): + index_field_val = "my_index" + instance = self._make_one_options(index_field=index_field_val) + assert instance.index_field == index_field_val + + def test_repr(self): + instance = self._make_one_options(index_field="my_idx") + repr_str = repr(instance) + assert repr_str == "UnnestOptions(index_field='my_idx')" + + def _make_one(self, *args, **kwargs): + return stages.Unnest(*args, **kwargs) + + def test_ctor(self): + instance = self._make_one("my_field") + assert isinstance(instance.field, Field) + 
assert instance.field.path == "my_field" + assert isinstance(instance.alias, Field) + assert instance.alias.path == "my_field" + assert instance.options is None + assert instance.name == "unnest" + + def test_ctor_full(self): + """constructor with alias and options set""" + field = Field.of("items") + alias = Field.of("alias") + options = stages.UnnestOptions(index_field="item_index") + instance = self._make_one(field, alias, options=options) + assert isinstance(field, Field) + assert instance.field == field + assert isinstance(alias, Field) + assert instance.alias == alias + assert instance.options == options + assert instance.name == "unnest" + + def test_repr(self): + instance_simple = self._make_one("my_field") + repr_str_simple = repr(instance_simple) + assert ( + repr_str_simple + == "Unnest(field=Field.of('my_field'), alias=Field.of('my_field'), options=None)" + ) + + options = stages.UnnestOptions(index_field="item_idx") + instance_full = self._make_one( + Field.of("items"), Field.of("alias"), options=options + ) + repr_str_full = repr(instance_full) + assert ( + repr_str_full + == "Unnest(field=Field.of('items'), alias=Field.of('alias'), options=UnnestOptions(index_field='item_idx'))" + ) + + def test_to_pb(self): + instance = self._make_one(Field.of("dataPoints")) + result = instance._to_pb() + assert result.name == "unnest" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "dataPoints" + assert result.args[1].field_reference_value == "dataPoints" + assert len(result.options) == 0 + + def test_to_pb_full(self): + field_str = "items" + alias_str = "single_item" + options_val = stages.UnnestOptions(index_field="item_index") + instance = self._make_one(field_str, alias_str, options=options_val) + + result = instance._to_pb() + assert result.name == "unnest" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == field_str + assert result.args[1].field_reference_value == alias_str + + assert len(result.options) == 1 + assert result.options["index_field"].string_value == "item_index" + + +class TestWhere: + def _make_one(self, *args, **kwargs): + return stages.Where(*args, **kwargs) + + def test_repr(self): + condition = Field.of("age").gt(30) + instance = self._make_one(condition) + repr_str = repr(instance) + assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))" + + def test_to_pb(self): + condition = Field.of("city").eq("SF") + instance = self._make_one(condition) + result = instance._to_pb() + assert result.name == "where" + assert len(result.args) == 1 + got_fn = result.args[0].function_value + assert got_fn.name == "eq" + assert len(got_fn.args) == 2 + assert got_fn.args[0].field_reference_value == "city" + assert got_fn.args[1].string_value == "SF" + assert len(result.options) == 0 From cd578c12ec78eaf37acbac6849eff6dbee303a14 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Oct 2025 16:53:46 -0700 Subject: [PATCH 04/27] chore: update and refactor pipeline expressions (#1111) --- google/cloud/firestore_v1/_pipeline_stages.py | 43 +- google/cloud/firestore_v1/base_pipeline.py | 19 +- .../firestore_v1/pipeline_expressions.py | 1866 +++++------------ tests/system/pipeline_e2e.yaml | 558 +++-- tests/system/test_pipeline_acceptance.py | 67 +- tests/unit/v1/test_async_pipeline.py | 3 +- tests/unit/v1/test_pipeline.py | 3 +- tests/unit/v1/test_pipeline_expressions.py | 901 ++++---- tests/unit/v1/test_pipeline_stages.py | 32 +- 9 files changed, 1464 insertions(+), 2028 deletions(-) diff --git 
a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index f7d311d89..aefddbcf8 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -23,11 +23,12 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AggregateFunction, Expr, - ExprWithAlias, + AliasedAggregate, + AliasedExpr, Field, - FilterCondition, + BooleanExpr, Selectable, Ordering, ) @@ -156,13 +157,7 @@ def __init__(self, *fields: Selectable): self.fields = list(fields) def _pb_args(self): - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] + return [Selectable._to_value(self.fields)] class Aggregate(Stage): @@ -170,8 +165,8 @@ class Aggregate(Stage): def __init__( self, - *args: ExprWithAlias[Accumulator], - accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + *args: AliasedExpr[AggregateFunction], + accumulators: Sequence[AliasedAggregate] = (), groups: Sequence[str | Selectable] = (), ): super().__init__() @@ -186,18 +181,8 @@ def __init__( def _pb_args(self): return [ - Value( - map_value={ - "fields": { - m[0]: m[1] for m in [f._to_map() for f in self.accumulators] - } - } - ), - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]} - } - ), + Selectable._to_value(self.accumulators), + Selectable._to_value(self.groups), ] def __repr__(self): @@ -254,13 +239,7 @@ def __init__(self, *fields: str | Selectable): ] def _pb_args(self) -> list[Value]: - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] + return [Selectable._to_value(self.fields)] class Documents(Stage): @@ -461,7 +440,7 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" - def __init__(self, condition: FilterCondition): + def __init__(self, condition: BooleanExpr): super().__init__() self.condition = condition diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 50ae7ab62..01f48ee78 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -23,11 +23,10 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AliasedAggregate, Expr, - ExprWithAlias, Field, - FilterCondition, + BooleanExpr, Selectable, ) from google.cloud.firestore_v1 import _helpers @@ -220,14 +219,14 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": """ return self._append(stages.Select(*selections)) - def where(self, condition: FilterCondition) -> "_BasePipeline": + def where(self, condition: BooleanExpr) -> "_BasePipeline": """ Filters the documents from previous stages to only include those matching - the specified `FilterCondition`. + the specified `BooleanExpr`. This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using - implementations of `FilterCondition`, typically including but not limited to: + implementations of `BooleanExpr`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. 
- logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. @@ -252,7 +251,7 @@ def where(self, condition: FilterCondition) -> "_BasePipeline": Args: - condition: The `FilterCondition` to apply. + condition: The `BooleanExpr` to apply. Returns: A new Pipeline object with this stage appended to the stage list @@ -531,7 +530,7 @@ def limit(self, limit: int) -> "_BasePipeline": def aggregate( self, - *accumulators: ExprWithAlias[Accumulator], + *accumulators: AliasedAggregate, groups: Sequence[str | Selectable] = (), ) -> "_BasePipeline": """ @@ -541,7 +540,7 @@ def aggregate( This stage allows you to calculate aggregate values (like sum, average, count, min, max) over a set of documents. - - **Accumulators:** Define the aggregation calculations using `Accumulator` + - **Accumulators:** Define the aggregation calculations using `AggregateFunction` expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined with `as_()` to name the result field. - **Groups:** Optionally specify fields (by name or `Selectable`) to group @@ -569,7 +568,7 @@ def aggregate( Args: - *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining + *accumulators: One or more `AliasedAggregate` expressions defining the aggregations to perform and their output names. groups: An optional sequence of field names (str) or `Selectable` expressions to group by before aggregating. diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 70d619d3b..ef57f5b72 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -15,7 +15,6 @@ from __future__ import annotations from typing import ( Any, - List, Generic, TypeVar, Dict, @@ -117,7 +116,35 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Expr | float) -> "Add": + class expose_as_static: + """ + Decorator to mark instance methods to be exposed as static methods as well as instance + methods. + + When called statically, the first argument is converted to a Field expression if needed. + + Example: + >>> Field.of("test").add(5) + >>> Function.add("test", 5) + """ + + def __init__(self, instance_func): + self.instance_func = instance_func + + def static_func(self, first_arg, *other_args, **kwargs): + first_expr = ( + Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg + ) + return self.instance_func(first_expr, *other_args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self.static_func.__get__(instance, owner) + else: + return self.instance_func.__get__(instance, owner) + + @expose_as_static + def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. Example: @@ -132,9 +159,10 @@ def add(self, other: Expr | float) -> "Add": Returns: A new `Expr` representing the addition operation. """ - return Add(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function("add", [self, self._cast_to_expr_or_convert_to_constant(other)]) - def subtract(self, other: Expr | float) -> "Subtract": + @expose_as_static + def subtract(self, other: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. 
Example: @@ -149,9 +177,12 @@ def subtract(self, other: Expr | float) -> "Subtract": Returns: A new `Expr` representing the subtraction operation. """ - return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "subtract", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def multiply(self, other: Expr | float) -> "Multiply": + @expose_as_static + def multiply(self, other: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -166,9 +197,12 @@ def multiply(self, other: Expr | float) -> "Multiply": Returns: A new `Expr` representing the multiplication operation. """ - return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "multiply", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def divide(self, other: Expr | float) -> "Divide": + @expose_as_static + def divide(self, other: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. Example: @@ -183,9 +217,12 @@ def divide(self, other: Expr | float) -> "Divide": Returns: A new `Expr` representing the division operation. """ - return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "divide", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def mod(self, other: Expr | float) -> "Mod": + @expose_as_static + def mod(self, other: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -200,9 +237,10 @@ def mod(self, other: Expr | float) -> "Mod": Returns: A new `Expr` representing the modulo operation. """ - return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + @expose_as_static + def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -211,19 +249,24 @@ def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": Example: >>> # Returns the larger value between the 'discount' field and the 'cap' field. - >>> Field.of("discount").logical_max(Field.of("cap")) + >>> Field.of("discount").logical_maximum(Field.of("cap")) >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_max(10) + >>> Field.of("value").logical_maximum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical max operation. + A new `Expr` representing the logical maximum operation. """ - return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "maximum", + [self, self._cast_to_expr_or_convert_to_constant(other)], + infix_name_override="logical_maximum", + ) - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + @expose_as_static + def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -232,27 +275,32 @@ def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": Example: >>> # Returns the smaller value between the 'discount' field and the 'floor' field. 
- >>> Field.of("discount").logical_min(Field.of("floor")) + >>> Field.of("discount").logical_minimum(Field.of("floor")) >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_min(10) + >>> Field.of("value").logical_minimum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical min operation. + A new `Expr` representing the logical minimum operation. """ - return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "minimum", + [self, self._cast_to_expr_or_convert_to_constant(other)], + infix_name_override="logical_minimum", + ) - def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + @expose_as_static + def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. Example: >>> # Check if the 'age' field is equal to 21 - >>> Field.of("age").eq(21) + >>> Field.of("age").equal(21) >>> # Check if the 'city' field is equal to "London" - >>> Field.of("city").eq("London") + >>> Field.of("city").equal("London") Args: other: The expression or constant value to compare for equality. @@ -260,17 +308,20 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": Returns: A new `Expr` representing the equality comparison. """ - return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + @expose_as_static + def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. Example: >>> # Check if the 'status' field is not equal to "completed" - >>> Field.of("status").neq("completed") + >>> Field.of("status").not_equal("completed") >>> # Check if the 'country' field is not equal to "USA" - >>> Field.of("country").neq("USA") + >>> Field.of("country").not_equal("USA") Args: other: The expression or constant value to compare for inequality. @@ -278,17 +329,20 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": Returns: A new `Expr` representing the inequality comparison. """ - return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "not_equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + @expose_as_static + def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. Example: >>> # Check if the 'age' field is greater than the 'limit' field - >>> Field.of("age").gt(Field.of("limit")) + >>> Field.of("age").greater_than(Field.of("limit")) >>> # Check if the 'price' field is greater than 100 - >>> Field.of("price").gt(100) + >>> Field.of("price").greater_than(100) Args: other: The expression or constant value to compare for greater than. @@ -296,17 +350,20 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": Returns: A new `Expr` representing the greater than comparison. 
""" - return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "greater_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + @expose_as_static + def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 - >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> Field.of("quantity").greater_than_or_equal(Field.of('requirement').add(1)) >>> # Check if the 'score' field is greater than or equal to 80 - >>> Field.of("score").gte(80) + >>> Field.of("score").greater_than_or_equal(80) Args: other: The expression or constant value to compare for greater than or equal to. @@ -314,17 +371,21 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": Returns: A new `Expr` representing the greater than or equal to comparison. """ - return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "greater_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) - def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + @expose_as_static + def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. Example: >>> # Check if the 'age' field is less than 'limit' - >>> Field.of("age").lt(Field.of('limit')) + >>> Field.of("age").less_than(Field.of('limit')) >>> # Check if the 'price' field is less than 50 - >>> Field.of("price").lt(50) + >>> Field.of("price").less_than(50) Args: other: The expression or constant value to compare for less than. @@ -332,17 +393,20 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": Returns: A new `Expr` representing the less than comparison. """ - return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "less_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + @expose_as_static + def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is less than or equal to 20 - >>> Field.of("quantity").lte(Constant.of(20)) + >>> Field.of("quantity").less_than_or_equal(Constant.of(20)) >>> # Check if the 'score' field is less than or equal to 70 - >>> Field.of("score").lte(70) + >>> Field.of("score").less_than_or_equal(70) Args: other: The expression or constant value to compare for less than or equal to. @@ -350,15 +414,19 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": Returns: A new `Expr` representing the less than or equal to comparison. """ - return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "less_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) - def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": + @expose_as_static + def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. 
Example: >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' - >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) + >>> Field.of("category").equal_any(["Electronics", Field.of("primaryType")]) Args: array: The values or expressions to check against. @@ -366,25 +434,43 @@ def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": Returns: A new `Expr` representing the 'IN' comparison. """ - return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + return BooleanExpr( + "equal_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(v) for v in array] + ), + ], + ) - def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": + @expose_as_static + def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. Example: >>> # Check if the 'status' field is neither "pending" nor "cancelled" - >>> Field.of("status").not_in_any(["pending", "cancelled"]) + >>> Field.of("status").not_equal_any(["pending", "cancelled"]) Args: - *others: The values or expressions to check against. + array: The values or expressions to check against. Returns: A new `Expr` representing the 'NOT IN' comparison. """ - return Not(self.in_any(array)) + return BooleanExpr( + "not_equal_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(v) for v in array] + ), + ], + ) - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + @expose_as_static + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -399,11 +485,15 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": Returns: A new `Expr` representing the 'array_contains' comparison. """ - return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + return BooleanExpr( + "array_contains", [self, self._cast_to_expr_or_convert_to_constant(element)] + ) + @expose_as_static def array_contains_all( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": + self, + elements: Sequence[Expr | CONSTANT_TYPE], + ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -418,13 +508,21 @@ def array_contains_all( Returns: A new `Expr` representing the 'array_contains_all' comparison. """ - return ArrayContainsAll( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + return BooleanExpr( + "array_contains_all", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ), + ], ) + @expose_as_static def array_contains_any( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": + self, + elements: Sequence[Expr | CONSTANT_TYPE], + ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -440,11 +538,18 @@ def array_contains_any( Returns: A new `Expr` representing the 'array_contains_any' comparison. 
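+
+        Note: the elements are packed into a single list expression, so the
+        serialized ``array_contains_any`` function call always receives exactly
+        two arguments: the array expression and the list of candidate values.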
""" - return ArrayContainsAny( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + return BooleanExpr( + "array_contains_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ), + ], ) - def array_length(self) -> "ArrayLength": + @expose_as_static + def array_length(self) -> "Expr": """Creates an expression that calculates the length of an array. Example: @@ -454,9 +559,10 @@ def array_length(self) -> "ArrayLength": Returns: A new `Expr` representing the length of the array. """ - return ArrayLength(self) + return Function("array_length", [self]) - def array_reverse(self) -> "ArrayReverse": + @expose_as_static + def array_reverse(self) -> "Expr": """Creates an expression that returns the reversed content of an array. Example: @@ -466,9 +572,10 @@ def array_reverse(self) -> "ArrayReverse": Returns: A new `Expr` representing the reversed array. """ - return ArrayReverse(self) + return Function("array_reverse", [self]) - def is_nan(self) -> "IsNaN": + @expose_as_static + def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -478,9 +585,10 @@ def is_nan(self) -> "IsNaN": Returns: A new `Expr` representing the 'isNaN' check. """ - return IsNaN(self) + return BooleanExpr("is_nan", [self]) - def exists(self) -> "Exists": + @expose_as_static + def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. Example: @@ -490,9 +598,10 @@ def exists(self) -> "Exists": Returns: A new `Expr` representing the 'exists' check. """ - return Exists(self) + return BooleanExpr("exists", [self]) - def sum(self) -> "Sum": + @expose_as_static + def sum(self) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: @@ -500,24 +609,25 @@ def sum(self) -> "Sum": >>> Field.of("orderAmount").sum().as_("totalRevenue") Returns: - A new `Accumulator` representing the 'sum' aggregation. + A new `AggregateFunction` representing the 'sum' aggregation. """ - return Sum(self) + return AggregateFunction("sum", [self]) - def avg(self) -> "Avg": + @expose_as_static + def average(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. Example: >>> # Calculate the average age of users - >>> Field.of("age").avg().as_("averageAge") + >>> Field.of("age").average().as_("averageAge") Returns: - A new `Accumulator` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'avg' aggregation. """ - return Avg(self) + return AggregateFunction("average", [self]) - def count(self) -> "Count": + def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -526,35 +636,38 @@ def count(self) -> "Count": >>> Field.of("productId").count().as_("totalProducts") Returns: - A new `Accumulator` representing the 'count' aggregation. + A new `AggregateFunction` representing the 'count' aggregation. """ - return Count(self) + return AggregateFunction("count", [self]) - def min(self) -> "Min": + @expose_as_static + def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. 
Example: >>> # Find the lowest price of all products - >>> Field.of("price").min().as_("lowestPrice") + >>> Field.of("price").minimum().as_("lowestPrice") Returns: - A new `Accumulator` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'minimum' aggregation. """ - return Min(self) + return AggregateFunction("minimum", [self]) - def max(self) -> "Max": + @expose_as_static + def maximum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> # Find the highest score in a leaderboard - >>> Field.of("score").max().as_("highestScore") + >>> Field.of("score").maximum().as_("highestScore") Returns: - A new `Accumulator` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'maximum' aggregation. """ - return Max(self) + return AggregateFunction("maximum", [self]) - def char_length(self) -> "CharLength": + @expose_as_static + def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. Example: @@ -564,9 +677,10 @@ def char_length(self) -> "CharLength": Returns: A new `Expr` representing the length of the string. """ - return CharLength(self) + return Function("char_length", [self]) - def byte_length(self) -> "ByteLength": + @expose_as_static + def byte_length(self) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -576,9 +690,10 @@ def byte_length(self) -> "ByteLength": Returns: A new `Expr` representing the byte length of the string. """ - return ByteLength(self) + return Function("byte_length", [self]) - def like(self, pattern: Expr | str) -> "Like": + @expose_as_static + def like(self, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -593,9 +708,12 @@ def like(self, pattern: Expr | str) -> "Like": Returns: A new `Expr` representing the 'like' comparison. """ - return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + return BooleanExpr( + "like", [self, self._cast_to_expr_or_convert_to_constant(pattern)] + ) - def regex_contains(self, regex: Expr | str) -> "RegexContains": + @expose_as_static + def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -611,16 +729,19 @@ def regex_contains(self, regex: Expr | str) -> "RegexContains": Returns: A new `Expr` representing the 'contains' comparison. """ - return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + return BooleanExpr( + "regex_contains", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) - def regex_matches(self, regex: Expr | str) -> "RegexMatch": + @expose_as_static + def regex_match(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. Example: >>> # Check if the 'email' field matches a valid email pattern - >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> Field.of("email").regex_match("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") >>> # Check if the 'email' field matches a regular expression stored in field 'regex' - >>> Field.of("email").regex_matches(Field.of("regex")) + >>> Field.of("email").regex_match(Field.of("regex")) Args: regex: The regular expression (string or expression) to use for the match. 
@@ -628,16 +749,19 @@ def regex_matches(self, regex: Expr | str) -> "RegexMatch": Returns: A new `Expr` representing the regular expression match. """ - return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + return BooleanExpr( + "regex_match", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) - def str_contains(self, substring: Expr | str) -> "StrContains": + @expose_as_static + def string_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: >>> # Check if the 'description' field contains "example". - >>> Field.of("description").str_contains("example") + >>> Field.of("description").string_contains("example") >>> # Check if the 'description' field contains the value of the 'keyword' field. - >>> Field.of("description").str_contains(Field.of("keyword")) + >>> Field.of("description").string_contains(Field.of("keyword")) Args: substring: The substring (string or expression) to use for the search. @@ -645,9 +769,13 @@ def str_contains(self, substring: Expr | str) -> "StrContains": Returns: A new `Expr` representing the 'contains' comparison. """ - return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + return BooleanExpr( + "string_contains", + [self, self._cast_to_expr_or_convert_to_constant(substring)], + ) - def starts_with(self, prefix: Expr | str) -> "StartsWith": + @expose_as_static + def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. Example: @@ -662,9 +790,12 @@ def starts_with(self, prefix: Expr | str) -> "StartsWith": Returns: A new `Expr` representing the 'starts with' comparison. """ - return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + return BooleanExpr( + "starts_with", [self, self._cast_to_expr_or_convert_to_constant(prefix)] + ) - def ends_with(self, postfix: Expr | str) -> "EndsWith": + @expose_as_static + def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -679,14 +810,17 @@ def ends_with(self, postfix: Expr | str) -> "EndsWith": Returns: A new `Expr` representing the 'ends with' comparison. """ - return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + return BooleanExpr( + "ends_with", [self, self._cast_to_expr_or_convert_to_constant(postfix)] + ) - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + @expose_as_static + def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string - >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + >>> Field.of("firstName").string_concat(" ", Field.of("lastName")) Args: *elements: The expressions or constants (typically strings) to concatenate. @@ -694,16 +828,17 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": Returns: A new `Expr` representing the concatenated string. """ - return StrConcat( - self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] + return Function( + "string_concat", + [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) - def map_get(self, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. 
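+    # map_get accepts either a raw string key or a Constant[str]; a plain str
+    # is wrapped with Constant.of() in the method body below.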
+ @expose_as_static + def map_get(self, key: str | Constant[str]) -> "Expr": + """Accesses a value from the map produced by evaluating this expression. Example: - >>> # Get the 'city' value from - >>> # the 'address' map field + >>> Expr.map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -712,9 +847,12 @@ def map_get(self, key: str) -> "MapGet": Returns: A new `Expr` representing the value associated with the given key in the map. """ - return MapGet(self, Constant.of(key)) + return Function( + "map_get", [self, Constant.of(key) if isinstance(key, str) else key] + ) - def vector_length(self) -> "VectorLength": + @expose_as_static + def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -724,9 +862,10 @@ def vector_length(self) -> "VectorLength": Returns: A new `Expr` representing the length of the vector. """ - return VectorLength(self) + return Function("vector_length", [self]) - def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + @expose_as_static + def timestamp_to_unix_micros(self) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -739,9 +878,10 @@ def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": Returns: A new `Expr` representing the number of microseconds since the epoch. """ - return TimestampToUnixMicros(self) + return Function("timestamp_to_unix_micros", [self]) - def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + @expose_as_static + def unix_micros_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -752,9 +892,10 @@ def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixMicrosToTimestamp(self) + return Function("unix_micros_to_timestamp", [self]) - def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + @expose_as_static + def timestamp_to_unix_millis(self) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -767,9 +908,10 @@ def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": Returns: A new `Expr` representing the number of milliseconds since the epoch. """ - return TimestampToUnixMillis(self) + return Function("timestamp_to_unix_millis", [self]) - def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + @expose_as_static + def unix_millis_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -780,9 +922,10 @@ def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixMillisToTimestamp(self) + return Function("unix_millis_to_timestamp", [self]) - def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + @expose_as_static + def timestamp_to_unix_seconds(self) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -795,9 +938,10 @@ def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": Returns: A new `Expr` representing the number of seconds since the epoch. 
""" - return TimestampToUnixSeconds(self) + return Function("timestamp_to_unix_seconds", [self]) - def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + @expose_as_static + def unix_seconds_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -808,9 +952,10 @@ def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixSecondsToTimestamp(self) + return Function("unix_seconds_to_timestamp", [self]) - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + @expose_as_static + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -827,20 +972,24 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampAdd( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), + return Function( + "timestamp_add", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], ) - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + @expose_as_static + def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> Field.of("timestamp").timestamp_subtract(Field.of("unit"), Field.of("amount")) >>> # Subtract 2.5 hours from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + >>> Field.of("timestamp").timestamp_subtract("hour", 2.5) Args: unit: The expression or string evaluating to the unit of time to subtract, must be one of @@ -850,18 +999,34 @@ def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampSub( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), + return Function( + "timestamp_subtract", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], ) + @expose_as_static + def collection_id(self): + """Creates an expression that returns the collection ID from a path. + + Example: + >>> # Get the collection ID from a path. + >>> Field.of("__name__").collection_id() + + Returns: + A new `Expr` representing the collection ID. + """ + return Function("collection_id", [self]) + def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. Example: >>> # Sort documents by the 'name' field in ascending order - >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + >>> client.pipeline().collection("users").sort(Field.of("name").ascending()) Returns: A new `Ordering` for ascending sorting. 
@@ -873,14 +1038,14 @@ def descending(self) -> Ordering: Example: >>> # Sort documents by the 'createdAt' field in descending order - >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + >>> client.pipeline().collection("users").sort(Field.of("createdAt").descending()) Returns: A new `Ordering` for descending sorting. """ return Ordering(self, Ordering.Direction.DESCENDING) - def as_(self, alias: str) -> "ExprWithAlias": + def as_(self, alias: str) -> "AliasedExpr": """Assigns an alias to this expression. Aliases are useful for renaming fields in the output of a stage or for giving meaningful @@ -888,7 +1053,7 @@ def as_(self, alias: str) -> "ExprWithAlias": Example: >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. - >>> firestore.pipeline().collection("items").add_fields( + >>> client.pipeline().collection("items").add_fields( ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") ... ) @@ -896,10 +1061,10 @@ def as_(self, alias: str) -> "ExprWithAlias": alias: The alias to assign to this expression. Returns: - A new `Selectable` (typically an `ExprWithAlias`) that wraps this + A new `Selectable` (typically an `AliasedExpr`) that wraps this expression and associates it with the provided alias. """ - return ExprWithAlias(self, alias) + return AliasedExpr(self, alias) class Constant(Expr, Generic[CONSTANT_TYPE]): @@ -922,24 +1087,27 @@ def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: def __repr__(self): return f"Constant.of({self.value!r})" + def __hash__(self): + return hash(self.value) + def _to_pb(self) -> Value: return encode_value(self.value) -class ListOfExprs(Expr): +class _ListOfExprs(Expr): """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - def __init__(self, exprs: List[Expr]): - self.exprs: list[Expr] = exprs + def __init__(self, exprs: Sequence[Expr]): + self.exprs: list[Expr] = list(exprs) def __eq__(self, other): - if not isinstance(other, ListOfExprs): + if not isinstance(other, _ListOfExprs): return False else: return other.exprs == self.exprs def __repr__(self): - return f"{self.__class__.__name__}({self.exprs})" + return repr(self.exprs) def _to_pb(self): return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) @@ -948,9 +1116,34 @@ def _to_pb(self): class Function(Expr): """A base class for expressions that represent function calls.""" - def __init__(self, name: str, params: Sequence[Expr]): + def __init__( + self, + name: str, + params: Sequence[Expr], + *, + use_infix_repr: bool = True, + infix_name_override: str | None = None, + ): self.name = name self.params = list(params) + self._use_infix_repr = use_infix_repr + self._infix_name_override = infix_name_override + + def __repr__(self): + """ + Most Functions can be triggered infix. Eg: Field.of('age').greater_than(18). 
+ + Display them this way in the repr string where possible + """ + if self._use_infix_repr: + infix_name = self._infix_name_override or self.name + if len(self.params) == 1: + return f"{self.params[0]!r}.{infix_name}()" + elif len(self.params) == 2: + return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" + else: + return f"{self.params[0]!r}.{infix_name}({', '.join([repr(p) for p in self.params[1:]])})" + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" def __eq__(self, other): if not isinstance(other, Function): @@ -958,9 +1151,6 @@ def __eq__(self, other): else: return other.name == self.name and other.params == self.params - def __repr__(self): - return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" - def _to_pb(self): return Value( function_value={ @@ -969,1118 +1159,143 @@ def _to_pb(self): } ) - def add(left: Expr | str, right: Expr | float) -> "Add": - """Creates an expression that adds two expressions together. - - Example: - >>> Function.add("rating", 5) - >>> Function.add(Field.of("quantity"), Field.of("reserve")) - - Args: - left: The first expression or field path to add. - right: The second expression or constant value to add. - - Returns: - A new `Expr` representing the addition operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.add(left_expr, right) - - def subtract(left: Expr | str, right: Expr | float) -> "Subtract": - """Creates an expression that subtracts another expression or constant from this expression. - - Example: - >>> Function.subtract("total", 20) - >>> Function.subtract(Field.of("price"), Field.of("discount")) - - Args: - left: The expression or field path to subtract from. - right: The expression or constant value to subtract. - Returns: - A new `Expr` representing the subtraction operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.subtract(left_expr, right) +class AggregateFunction(Function): + """A base class for aggregation functions that operate across multiple inputs.""" - def multiply(left: Expr | str, right: Expr | float) -> "Multiply": - """Creates an expression that multiplies this expression by another expression or constant. + def as_(self, alias: str) -> "AliasedAggregate": + """Assigns an alias to this expression. - Example: - >>> Function.multiply("value", 2) - >>> Function.multiply(Field.of("quantity"), Field.of("price")) + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. Args: - left: The expression or field path to multiply. - right: The expression or constant value to multiply by. + alias: The alias to assign to this expression. - Returns: - A new `Expr` representing the multiplication operation. + Returns: A new AliasedAggregate that wraps this expression and associates it with the + provided alias. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.multiply(left_expr, right) + return AliasedAggregate(self, alias) - def divide(left: Expr | str, right: Expr | float) -> "Divide": - """Creates an expression that divides this expression by another expression or constant. - Example: - >>> Function.divide("value", 10) - >>> Function.divide(Field.of("total"), Field.of("count")) +class Selectable(Expr): + """Base class for expressions that can be selected or aliased in projection stages.""" - Args: - left: The expression or field path to be divided. 
- right: The expression or constant value to divide by. + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return other._to_map() == self._to_map() - Returns: - A new `Expr` representing the division operation. + @abstractmethod + def _to_map(self) -> tuple[str, Value]: """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.divide(left_expr, right) - - def mod(left: Expr | str, right: Expr | float) -> "Mod": - """Creates an expression that calculates the modulo (remainder) to another expression or constant. - - Example: - >>> Function.mod("value", 5) - >>> Function.mod(Field.of("value"), Field.of("divisor")) - - Args: - left: The dividend expression or field path. - right: The divisor expression or constant. - - Returns: - A new `Expr` representing the modulo operation. + Returns a str: Value representation of the Selectable """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.mod(left_expr, right) - - def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": - """Creates an expression that returns the larger value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_max("value", 10) - >>> Function.logical_max(Field.of("discount"), Field.of("cap")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. + raise NotImplementedError - Returns: - A new `Expr` representing the logical max operation. + @classmethod + def _value_from_selectables(cls, *selectables: Selectable) -> Value: """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_max(left_expr, right) - - def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": - """Creates an expression that returns the smaller value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_min("value", 10) - >>> Function.logical_min(Field.of("discount"), Field.of("floor")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical min operation. + Returns a Value representing a map of Selectables """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_min(left_expr, right) - - def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": - """Creates an expression that checks if this expression is equal to another - expression or constant value. - - Example: - >>> Function.eq("city", "London") - >>> Function.eq(Field.of("age"), 21) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for equality. + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} + } + ) - Returns: - A new `Expr` representing the equality comparison. 
- """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.eq(left_expr, right) + @staticmethod + def _to_value(field_list: Sequence[Selectable]) -> Value: + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in field_list]} + } + ) - def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": - """Creates an expression that checks if this expression is not equal to another - expression or constant value. - Example: - >>> Function.neq("country", "USA") - >>> Function.neq(Field.of("status"), "completed") +T = TypeVar("T", bound=Expr) - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for inequality. - Returns: - A new `Expr` representing the inequality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.neq(left_expr, right) +class AliasedExpr(Selectable, Generic[T]): + """Wraps an expression with an alias.""" - def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": - """Creates an expression that checks if this expression is greater than another - expression or constant value. + def __init__(self, expr: T, alias: str): + self.expr = expr + self.alias = alias - Example: - >>> Function.gt("price", 100) - >>> Function.gt(Field.of("age"), Field.of("limit")) + def _to_map(self): + return self.alias, self.expr._to_pb() - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than. + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" - Returns: - A new `Expr` representing the greater than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gt(left_expr, right) + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": - """Creates an expression that checks if this expression is greater than or equal - to another expression or constant value. - Example: - >>> Function.gte("score", 80) - >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) +class AliasedAggregate: + """Wraps an aggregate with an alias""" - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than or equal to. + def __init__(self, expr: AggregateFunction, alias: str): + self.expr = expr + self.alias = alias - Returns: - A new `Expr` representing the greater than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gte(left_expr, right) + def _to_map(self): + return self.alias, self.expr._to_pb() - def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": - """Creates an expression that checks if this expression is less than another - expression or constant value. + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" - Example: - >>> Function.lt("price", 50) - >>> Function.lt(Field.of("age"), Field.of('limit')) + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than. - Returns: - A new `Expr` representing the less than comparison. 
- """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lt(left_expr, right) +class Field(Selectable): + """Represents a reference to a field within a document.""" - def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": - """Creates an expression that checks if this expression is less than or equal to - another expression or constant value. + DOCUMENT_ID = "__name__" - Example: - >>> Function.lte("score", 70) - >>> Function.lte(Field.of("quantity"), Constant.of(20)) + def __init__(self, path: str): + """Initializes a Field reference. Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than or equal to. - - Returns: - A new `Expr` representing the less than or equal to comparison. + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lte(left_expr, right) - - def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": - """Creates an expression that checks if this expression is equal to any of the - provided values or expressions. + self.path = path - Example: - >>> Function.in_any("category", ["Electronics", "Apparel"]) - >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) + @staticmethod + def of(path: str): + """Creates a Field reference. Args: - left: The expression or field path to compare. - array: The values or expressions to check against. + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. Returns: - A new `Expr` representing the 'IN' comparison. + A new Field instance. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.in_any(left_expr, array) - - def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": - """Creates an expression that checks if this expression is not equal to any of the - provided values or expressions. + return Field(path) - Example: - >>> Function.not_in_any("status", ["pending", "cancelled"]) + def _to_map(self): + return self.path, self._to_pb() - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. + def __repr__(self): + return f"Field.of({self.path!r})" - Returns: - A new `Expr` representing the 'NOT IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_in_any(left_expr, array) + def _to_pb(self): + return Value(field_reference_value=self.path) - def array_contains( - array: Expr | str, element: Expr | CONSTANT_TYPE - ) -> "ArrayContains": - """Creates an expression that checks if an array contains a specific element or value. - Example: - >>> Function.array_contains("colors", "red") - >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) - - Args: - array: The array expression or field path to check. - element: The element (expression or constant) to search for in the array. - - Returns: - A new `Expr` representing the 'array_contains' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains(array_expr, element) - - def array_contains_all( - array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": - """Creates an expression that checks if an array contains all the specified elements. 
- - Example: - >>> Function.array_contains_all("tags", ["news", "sports"]) - >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_all' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_all(array_expr, elements) - - def array_contains_any( - array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": - """Creates an expression that checks if an array contains any of the specified elements. - - Example: - >>> Function.array_contains_any("groups", ["admin", "editor"]) - >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_any' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_any(array_expr, elements) - - def array_length(array: Expr | str) -> "ArrayLength": - """Creates an expression that calculates the length of an array. - - Example: - >>> Function.array_length("cart") - - Returns: - A new `Expr` representing the length of the array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_length(array_expr) - - def array_reverse(array: Expr | str) -> "ArrayReverse": - """Creates an expression that returns the reversed content of an array. - - Example: - >>> Function.array_reverse("preferences") - - Returns: - A new `Expr` representing the reversed array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_reverse(array_expr) - - def is_nan(expr: Expr | str) -> "IsNaN": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - - Example: - >>> Function.is_nan("measurement") - - Returns: - A new `Expr` representing the 'isNaN' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.is_nan(expr_val) - - def exists(expr: Expr | str) -> "Exists": - """Creates an expression that checks if a field exists in the document. - - Example: - >>> Function.exists("phoneNumber") - - Returns: - A new `Expr` representing the 'exists' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.exists(expr_val) - - def sum(expr: Expr | str) -> "Sum": - """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. - - Example: - >>> Function.sum("orderAmount") - - Returns: - A new `Accumulator` representing the 'sum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.sum(expr_val) - - def avg(expr: Expr | str) -> "Avg": - """Creates an aggregation that calculates the average (mean) of a numeric field across multiple - stage inputs. - - Example: - >>> Function.avg("age") - - Returns: - A new `Accumulator` representing the 'avg' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.avg(expr_val) - - def count(expr: Expr | str | None = None) -> "Count": - """Creates an aggregation that counts the number of stage inputs with valid evaluations of the - expression or field. 
If no expression is provided, it counts all inputs. - - Example: - >>> Function.count("productId") - >>> Function.count() - - Returns: - A new `Accumulator` representing the 'count' aggregation. - """ - if expr is None: - return Count() - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.count(expr_val) - - def min(expr: Expr | str) -> "Min": - """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. - - Example: - >>> Function.min("price") - - Returns: - A new `Accumulator` representing the 'min' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.min(expr_val) - - def max(expr: Expr | str) -> "Max": - """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. - - Example: - >>> Function.max("score") - - Returns: - A new `Accumulator` representing the 'max' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.max(expr_val) - - def char_length(expr: Expr | str) -> "CharLength": - """Creates an expression that calculates the character length of a string. - - Example: - >>> Function.char_length("name") - - Returns: - A new `Expr` representing the length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.char_length(expr_val) - - def byte_length(expr: Expr | str) -> "ByteLength": - """Creates an expression that calculates the byte length of a string in its UTF-8 form. - - Example: - >>> Function.byte_length("name") - - Returns: - A new `Expr` representing the byte length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.byte_length(expr_val) - - def like(expr: Expr | str, pattern: Expr | str) -> "Like": - """Creates an expression that performs a case-sensitive string comparison. - - Example: - >>> Function.like("title", "%guide%") - >>> Function.like(Field.of("title"), Field.of("pattern")) - - Args: - expr: The expression or field path to perform the comparison on. - pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. - - Returns: - A new `Expr` representing the 'like' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.like(expr_val, pattern) - - def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": - """Creates an expression that checks if a string contains a specified regular expression as a - substring. - - Example: - >>> Function.regex_contains("description", "(?i)example") - >>> Function.regex_contains(Field.of("description"), Field.of("regex")) - - Args: - expr: The expression or field path to perform the comparison on. - regex: The regular expression (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_contains(expr_val, regex) - - def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": - """Creates an expression that checks if a string matches a specified regular expression. - - Example: - >>> # Check if the 'email' field matches a valid email pattern - >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") - >>> Function.regex_matches(Field.of("email"), Field.of("regex")) - - Args: - expr: The expression or field path to match against. 
- regex: The regular expression (string or expression) to use for the match. - - Returns: - A new `Expr` representing the regular expression match. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_matches(expr_val, regex) - - def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": - """Creates an expression that checks if this string expression contains a specified substring. - - Example: - >>> Function.str_contains("description", "example") - >>> Function.str_contains(Field.of("description"), Field.of("keyword")) - - Args: - expr: The expression or field path to perform the comparison on. - substring: The substring (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.str_contains(expr_val, substring) - - def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": - """Creates an expression that checks if a string starts with a given prefix. - - Example: - >>> Function.starts_with("name", "Mr.") - >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) - - Args: - expr: The expression or field path to check. - prefix: The prefix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'starts with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.starts_with(expr_val, prefix) - - def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": - """Creates an expression that checks if a string ends with a given postfix. - - Example: - >>> Function.ends_with("filename", ".txt") - >>> Function.ends_with(Field.of("url"), Field.of("extension")) - - Args: - expr: The expression or field path to check. - postfix: The postfix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'ends with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ends_with(expr_val, postfix) - - def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": - """Creates an expression that concatenates string expressions, fields or constants together. - - Example: - >>> Function.str_concat("firstName", " ", Field.of("lastName")) - - Args: - first: The first expression or field path to concatenate. - *elements: The expressions or constants (typically strings) to concatenate. - - Returns: - A new `Expr` representing the concatenated string. - """ - first_expr = Field.of(first) if isinstance(first, str) else first - return Expr.str_concat(first_expr, *elements) - - def map_get(map_expr: Expr | str, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. - - Example: - >>> Function.map_get("address", "city") - - Args: - map_expr: The expression or field path of the map. - key: The key to access in the map. - - Returns: - A new `Expr` representing the value associated with the given key in the map. - """ - map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr - return Expr.map_get(map_val, key) - - def vector_length(vector_expr: Expr | str) -> "VectorLength": - """Creates an expression that calculates the length (dimension) of a Firestore Vector. - - Example: - >>> Function.vector_length("embedding") - - Returns: - A new `Expr` representing the length of the vector. 
- """ - vector_val = ( - Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr - ) - return Expr.vector_length(vector_val) - - def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": - """Creates an expression that converts a timestamp to the number of microseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the microsecond. - - Example: - >>> Function.timestamp_to_unix_micros("timestamp") - - Returns: - A new `Expr` representing the number of microseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_micros(timestamp_val) - - def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": - """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_micros_to_timestamp("microseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - micros_val = ( - Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr - ) - return Expr.unix_micros_to_timestamp(micros_val) - - def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": - """Creates an expression that converts a timestamp to the number of milliseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the millisecond. - - Example: - >>> Function.timestamp_to_unix_millis("timestamp") - - Returns: - A new `Expr` representing the number of milliseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_millis(timestamp_val) - - def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": - """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_millis_to_timestamp("milliseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - millis_val = ( - Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr - ) - return Expr.unix_millis_to_timestamp(millis_val) - - def timestamp_to_unix_seconds( - timestamp_expr: Expr | str, - ) -> "TimestampToUnixSeconds": - """Creates an expression that converts a timestamp to the number of seconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the second. - - Example: - >>> Function.timestamp_to_unix_seconds("timestamp") - - Returns: - A new `Expr` representing the number of seconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_seconds(timestamp_val) - - def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": - """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 - UTC) to a timestamp. - - Example: - >>> Function.unix_seconds_to_timestamp("seconds") - - Returns: - A new `Expr` representing the timestamp. 
- """ - seconds_val = ( - Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr - ) - return Expr.unix_seconds_to_timestamp(seconds_val) - - def timestamp_add( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampAdd": - """Creates an expression that adds a specified amount of time to this timestamp expression. - - Example: - >>> Function.timestamp_add("timestamp", "day", 1.5) - >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to add, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to add. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_add(timestamp_expr, unit, amount) - - def timestamp_sub( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampSub": - """Creates an expression that subtracts a specified amount of time from this timestamp expression. - - Example: - >>> Function.timestamp_sub("timestamp", "hour", 2.5) - >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to subtract, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to subtract. - - Returns: - A new `Expr` representing the resulting timestamp. 
- """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_sub(timestamp_expr, unit, amount) - - -class Divide(Function): - """Represents the division function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("divide", [left, right]) - - -class LogicalMax(Function): - """Represents the logical maximum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_maximum", [left, right]) - - -class LogicalMin(Function): - """Represents the logical minimum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_minimum", [left, right]) - - -class MapGet(Function): - """Represents accessing a value within a map by key.""" - - def __init__(self, map_: Expr, key: Constant[str]): - super().__init__("map_get", [map_, key]) - - -class Mod(Function): - """Represents the modulo function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("mod", [left, right]) - - -class Multiply(Function): - """Represents the multiplication function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("multiply", [left, right]) - - -class Parent(Function): - """Represents getting the parent document reference.""" - - def __init__(self, value: Expr): - super().__init__("parent", [value]) - - -class StrConcat(Function): - """Represents concatenating multiple strings.""" - - def __init__(self, *exprs: Expr): - super().__init__("str_concat", exprs) - - -class Subtract(Function): - """Represents the subtraction function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("subtract", [left, right]) - - -class TimestampAdd(Function): - """Represents adding a duration to a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_add", [timestamp, unit, amount]) - - -class TimestampSub(Function): - """Represents subtracting a duration from a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_sub", [timestamp, unit, amount]) - - -class TimestampToUnixMicros(Function): - """Represents converting a timestamp to microseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_micros", [input]) - - -class TimestampToUnixMillis(Function): - """Represents converting a timestamp to milliseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_millis", [input]) - - -class TimestampToUnixSeconds(Function): - """Represents converting a timestamp to seconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_seconds", [input]) - - -class UnixMicrosToTimestamp(Function): - """Represents converting microseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_micros_to_timestamp", [input]) - - -class UnixMillisToTimestamp(Function): - """Represents converting milliseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_millis_to_timestamp", [input]) - - -class UnixSecondsToTimestamp(Function): - """Represents converting seconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_seconds_to_timestamp", [input]) - - -class VectorLength(Function): - """Represents getting the length (dimension) of a vector.""" - - def 
__init__(self, array: Expr): - super().__init__("vector_length", [array]) - - -class Add(Function): - """Represents the addition function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("add", [left, right]) - - -class ArrayElement(Function): - """Represents accessing an element within an array""" - - def __init__(self): - super().__init__("array_element", []) - - -class ArrayFilter(Function): - """Represents filtering elements from an array based on a condition.""" - - def __init__(self, array: Expr, filter: "FilterCondition"): - super().__init__("array_filter", [array, filter]) - - -class ArrayLength(Function): - """Represents getting the length of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_length", [array]) - - -class ArrayReverse(Function): - """Represents reversing the elements of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_reverse", [array]) - - -class ArrayTransform(Function): - """Represents applying a transformation function to each element of an array.""" - - def __init__(self, array: Expr, transform: Function): - super().__init__("array_transform", [array, transform]) - - -class ByteLength(Function): - """Represents getting the byte length of a string (UTF-8).""" - - def __init__(self, expr: Expr): - super().__init__("byte_length", [expr]) - - -class CharLength(Function): - """Represents getting the character length of a string.""" - - def __init__(self, expr: Expr): - super().__init__("char_length", [expr]) - - -class CollectionId(Function): - """Represents getting the collection ID from a document reference.""" - - def __init__(self, value: Expr): - super().__init__("collection_id", [value]) - - -class Accumulator(Function): - """A base class for aggregation functions that operate across multiple inputs.""" - - -class Max(Accumulator): - """Represents the maximum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("maximum", [value]) - - -class Min(Accumulator): - """Represents the minimum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("minimum", [value]) - - -class Sum(Accumulator): - """Represents the sum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("sum", [value]) - - -class Avg(Accumulator): - """Represents the average aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("avg", [value]) - - -class Count(Accumulator): - """Represents an aggregation that counts the total number of inputs.""" - - def __init__(self, value: Expr | None = None): - super().__init__("count", [value] if value else []) - - -class Selectable(Expr): - """Base class for expressions that can be selected or aliased in projection stages.""" - - def __eq__(self, other): - if not isinstance(other, type(self)): - return False - else: - return other._to_map() == self._to_map() - - @abstractmethod - def _to_map(self) -> tuple[str, Value]: - """ - Returns a str: Value representation of the Selectable - """ - raise NotImplementedError - - @classmethod - def _value_from_selectables(cls, *selectables: Selectable) -> Value: - """ - Returns a Value representing a map of Selectables - """ - return Value( - map_value={ - "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} - } - ) - - -T = TypeVar("T", bound=Expr) - - -class ExprWithAlias(Selectable, Generic[T]): - """Wraps an expression with an alias.""" - - def __init__(self, expr: T, alias: str): - self.expr = expr - self.alias = alias - - 
def _to_map(self): - return self.alias, self.expr._to_pb() - - def __repr__(self): - return f"{self.expr}.as_('{self.alias}')" - - def _to_pb(self): - return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - - -class Field(Selectable): - """Represents a reference to a field within a document.""" - - DOCUMENT_ID = "__name__" - - def __init__(self, path: str): - """Initializes a Field reference. - - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - """ - self.path = path - - @staticmethod - def of(path: str): - """Creates a Field reference. - - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - - Returns: - A new Field instance. - """ - return Field(path) - - def _to_map(self): - return self.path, self._to_pb() - - def __repr__(self): - return f"Field.of({self.path!r})" - - def _to_pb(self): - return Value(field_reference_value=self.path) - - -class FilterCondition(Function): - """Filters the given data in some way.""" - - def __init__( - self, - *args, - use_infix_repr: bool = True, - infix_name_override: str | None = None, - **kwargs, - ): - self._use_infix_repr = use_infix_repr - self._infix_name_override = infix_name_override - super().__init__(*args, **kwargs) - - def __repr__(self): - """ - Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). - - Display them this way in the repr string where possible - """ - if self._use_infix_repr: - infix_name = self._infix_name_override or self.name - if len(self.params) == 1: - return f"{self.params[0]!r}.{infix_name}()" - elif len(self.params) == 2: - return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" - return super().__repr__() +class BooleanExpr(Function): + """Filters the given data in some way.""" @staticmethod def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - FilterCondition._from_query_filter_pb(f, client) - for f in filter_pb.filters + BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) @@ -2097,34 +1312,34 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: return And(field.exists(), Not(field.is_nan())) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.eq(None)) + return And(field.exists(), field.equal(None)) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.eq(None))) + return And(field.exists(), Not(field.equal(None))) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): field = Field.of(filter_pb.field.field_path) value = decode_value(filter_pb.value, client) if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: - return And(field.exists(), field.lt(value)) + return And(field.exists(), field.less_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: - return And(field.exists(), field.lte(value)) + return And(field.exists(), field.less_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: - return And(field.exists(), field.gt(value)) + return And(field.exists(), field.greater_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: - return 
And(field.exists(), field.gte(value)) + return And(field.exists(), field.greater_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: - return And(field.exists(), field.eq(value)) + return And(field.exists(), field.equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.neq(value)) + return And(field.exists(), field.not_equal(value)) if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: return And(field.exists(), field.array_contains(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: return And(field.exists(), field.array_contains_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: - return And(field.exists(), field.in_any(value)) + return And(field.exists(), field.equal_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_in_any(value)) + return And(field.exists(), field.not_equal_any(value)) else: raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): @@ -2134,169 +1349,94 @@ def _from_query_filter_pb(filter_pb, client): or filter_pb.field_filter or filter_pb.unary_filter ) - return FilterCondition._from_query_filter_pb(f, client) + return BooleanExpr._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") -class And(FilterCondition): - def __init__(self, *conditions: "FilterCondition"): - super().__init__("and", conditions, use_infix_repr=False) - - -class ArrayContains(FilterCondition): - def __init__(self, array: Expr, element: Expr): - super().__init__( - "array_contains", [array, element if element else Constant(None)] - ) - - -class ArrayContainsAll(FilterCondition): - """Represents checking if an array contains all specified elements.""" - - def __init__(self, array: Expr, elements: List[Expr]): - super().__init__("array_contains_all", [array, ListOfExprs(elements)]) - +class And(BooleanExpr): + """ + Represents an expression that performs a logical 'AND' operation on multiple filter conditions. -class ArrayContainsAny(FilterCondition): - """Represents checking if an array contains any of the specified elements.""" + Example: + >>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND + >>> # the 'status' field is "active" + >>> Expr.And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) - def __init__(self, array: Expr, elements: List[Expr]): - super().__init__("array_contains_any", [array, ListOfExprs(elements)]) + Args: + *conditions: The filter conditions to 'AND' together. + """ + def __init__(self, *conditions: "BooleanExpr"): + super().__init__("and", conditions, use_infix_repr=False) -class EndsWith(FilterCondition): - """Represents checking if a string ends with a specific postfix.""" - def __init__(self, expr: Expr, postfix: Expr): - super().__init__("ends_with", [expr, postfix]) +class Not(BooleanExpr): + """ + Represents an expression that negates a filter condition. + Example: + >>> # Find documents where the 'completed' field is NOT true + >>> Expr.Not(Field.of("completed").equal(True)) -class Eq(FilterCondition): - """Represents the equality comparison.""" + Args: + condition: The filter condition to negate. 
+ """ - def __init__(self, left: Expr, right: Expr): - super().__init__("eq", [left, right if right else Constant(None)]) + def __init__(self, condition: BooleanExpr): + super().__init__("not", [condition], use_infix_repr=False) -class Exists(FilterCondition): - """Represents checking if a field exists.""" +class Or(BooleanExpr): + """ + Represents expression that performs a logical 'OR' operation on multiple filter conditions. - def __init__(self, expr: Expr): - super().__init__("exists", [expr]) + Example: + >>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR + >>> # the 'status' field is "active" + >>> Expr.Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + Args: + *conditions: The filter conditions to 'OR' together. + """ -class Gt(FilterCondition): - """Represents the greater than comparison.""" + def __init__(self, *conditions: "BooleanExpr"): + super().__init__("or", conditions, use_infix_repr=False) - def __init__(self, left: Expr, right: Expr): - super().__init__("gt", [left, right if right else Constant(None)]) +class Xor(BooleanExpr): + """ + Represents an expression that performs a logical 'XOR' (exclusive OR) operation on multiple filter conditions. -class Gte(FilterCondition): - """Represents the greater than or equal to comparison.""" + Example: + >>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London", + >>> # or 'status' is "active". + >>> Expr.Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) - def __init__(self, left: Expr, right: Expr): - super().__init__("gte", [left, right if right else Constant(None)]) + Args: + *conditions: The filter conditions to 'XOR' together. + """ + def __init__(self, conditions: Sequence["BooleanExpr"]): + super().__init__("xor", conditions, use_infix_repr=False) -class If(FilterCondition): - """Represents a conditional expression (if-then-else).""" - def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): - super().__init__( - "if", [condition, true_expr, false_expr if false_expr else Constant(None)] - ) +class Conditional(BooleanExpr): + """ + Represents a conditional expression that evaluates to a 'then' expression if a condition is true + and an 'else' expression if the condition is false. + Example: + >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor". + >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); -class In(FilterCondition): - """Represents checking if an expression's value is within a list of values.""" + Args: + condition: The condition to evaluate. + then_expr: The expression to return if the condition is true. 
+ else_expr: The expression to return if the condition is false + """ - def __init__(self, left: Expr, others: List[Expr]): + def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): super().__init__( - "in", [left, ListOfExprs(others)], infix_name_override="in_any" + "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) - - -class IsNaN(FilterCondition): - """Represents checking if a numeric value is NaN.""" - - def __init__(self, value: Expr): - super().__init__("is_nan", [value]) - - -class Like(FilterCondition): - """Represents a case-sensitive wildcard string comparison.""" - - def __init__(self, expr: Expr, pattern: Expr): - super().__init__("like", [expr, pattern]) - - -class Lt(FilterCondition): - """Represents the less than comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lt", [left, right if right else Constant(None)]) - - -class Lte(FilterCondition): - """Represents the less than or equal to comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lte", [left, right if right else Constant(None)]) - - -class Neq(FilterCondition): - """Represents the inequality comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("neq", [left, right if right else Constant(None)]) - - -class Not(FilterCondition): - """Represents the logical NOT of a filter condition.""" - - def __init__(self, condition: Expr): - super().__init__("not", [condition], use_infix_repr=False) - - -class Or(FilterCondition): - """Represents the logical OR of multiple filter conditions.""" - - def __init__(self, *conditions: "FilterCondition"): - super().__init__("or", conditions) - - -class RegexContains(FilterCondition): - """Represents checking if a string contains a substring matching a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_contains", [expr, regex]) - - -class RegexMatch(FilterCondition): - """Represents checking if a string fully matches a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_match", [expr, regex]) - - -class StartsWith(FilterCondition): - """Represents checking if a string starts with a specific prefix.""" - - def __init__(self, expr: Expr, prefix: Expr): - super().__init__("starts_with", [expr, prefix]) - - -class StrContains(FilterCondition): - """Represents checking if a string contains a specific substring.""" - - def __init__(self, expr: Expr, substring: Expr): - super().__init__("str_contains", [expr, substring]) - - -class Xor(FilterCondition): - """Represents the logical XOR of multiple filter conditions.""" - - def __init__(self, conditions: List["FilterCondition"]): - super().__init__("xor", conditions, use_infix_repr=False) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc262f4a9..50cc7c29d 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -125,13 +125,25 @@ data: awards: hugo: true nebula: true + timestamps: + ts1: + time: "1993-04-28T12:01:00.654321+00:00" + micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + vectors: + vec1: + embedding: [1.0, 2.0, 3.0] + vec2: + embedding: [4.0, 5.0, 6.0, 7.0] tests: - description: "testAggregates - count" pipeline: - Collection: books - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" assert_results: - count: 10 @@ -147,25 +159,28 @@ tests: count: functionValue: name: count + args: + - fieldReferenceValue: 
rating - mapValue: {} name: aggregate - description: "testAggregates - avg, count, max" pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" - - ExprWithAlias: - - Avg: + - AliasedExpr: + - Expr.average: - Field: rating - "avg_rating" - - ExprWithAlias: - - Max: + - AliasedExpr: + - Expr.maximum: - Field: rating - "max_rating" assert_results: @@ -183,7 +198,7 @@ tests: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: where - args: - mapValue: @@ -192,10 +207,12 @@ tests: functionValue: args: - fieldReferenceValue: rating - name: avg + name: average count: functionValue: name: count + args: + - fieldReferenceValue: rating max_rating: functionValue: args: @@ -207,7 +224,7 @@ tests: pipeline: - Collection: books - Where: - - Lt: + - Expr.less_than: - Field: published - Constant: 1900 - Aggregate: @@ -218,18 +235,18 @@ tests: pipeline: - Collection: books - Where: - - Lt: + - Expr.less_than: - Field: published - Constant: 1984 - Aggregate: accumulators: - - ExprWithAlias: - - Avg: + - AliasedExpr: + - Expr.average: - Field: rating - "avg_rating" groups: [genre] - Where: - - Gt: + - Expr.greater_than: - Field: avg_rating - Constant: 4.3 - Sort: @@ -254,7 +271,7 @@ tests: args: - fieldReferenceValue: published - integerValue: '1984' - name: lt + name: less_than name: where - args: - mapValue: @@ -263,7 +280,7 @@ tests: functionValue: args: - fieldReferenceValue: rating - name: avg + name: average - mapValue: fields: genre: @@ -274,7 +291,7 @@ tests: args: - fieldReferenceValue: avg_rating - doubleValue: 4.3 - name: gt + name: greater_than name: where - args: - mapValue: @@ -288,15 +305,16 @@ tests: pipeline: - Collection: books - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" - - ExprWithAlias: - - Max: + - AliasedExpr: + - Expr.maximum: - Field: rating - "max_rating" - - ExprWithAlias: - - Min: + - AliasedExpr: + - Expr.minimum: - Field: published - "min_published" assert_results: @@ -314,6 +332,8 @@ tests: fields: count: functionValue: + args: + - fieldReferenceValue: rating name: count max_rating: functionValue: @@ -384,14 +404,14 @@ tests: pipeline: - Collection: books - AddFields: - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: author - Constant: _ - Field: title - "author_title" - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: title - Constant: _ - Field: author @@ -445,14 +465,14 @@ tests: - fieldReferenceValue: author - stringValue: _ - fieldReferenceValue: title - name: str_concat + name: string_concat title_author: functionValue: args: - fieldReferenceValue: title - stringValue: _ - fieldReferenceValue: author - name: str_concat + name: string_concat name: add_fields - args: - fieldReferenceValue: title_author @@ -477,10 +497,10 @@ tests: - Collection: books - Where: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.5 - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction assert_results: @@ -509,12 +529,12 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: and name: where - description: whereByOrCondition @@ -522,10 +542,10 @@ tests: - Collection: books - Where: - Or: - - Eq: 
+ - Expr.equal: - Field: genre - Constant: Romance - - Eq: + - Expr.equal: - Field: genre - Constant: Dystopian - Select: @@ -551,12 +571,12 @@ tests: args: - fieldReferenceValue: genre - stringValue: Romance - name: eq + name: equal - functionValue: args: - fieldReferenceValue: genre - stringValue: Dystopian - name: eq + name: equal name: or name: where - args: @@ -624,7 +644,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContains: + - Expr.array_contains: - Field: tags - Constant: comedy assert_results: @@ -654,7 +674,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContainsAny: + - Expr.array_contains_any: - Field: tags - - Constant: comedy - Constant: classic @@ -701,7 +721,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContainsAll: + - Expr.array_contains_all: - Field: tags - - Constant: adventure - Constant: magic @@ -735,12 +755,12 @@ tests: pipeline: - Collection: books - Select: - - ExprWithAlias: - - ArrayLength: + - AliasedExpr: + - Expr.array_length: - Field: tags - "tagsCount" - Where: - - Eq: + - Expr.equal: - Field: tagsCount - Constant: 3 assert_results: # All documents have 3 tags @@ -774,9 +794,9 @@ tests: args: - fieldReferenceValue: tagsCount - integerValue: '3' - name: eq + name: equal name: where - - description: testStrConcat + - description: testStringConcat pipeline: - Collection: books - Sort: @@ -784,8 +804,8 @@ tests: - Field: author - ASCENDING - Select: - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: author - Constant: " - " - Field: title @@ -816,7 +836,7 @@ tests: - fieldReferenceValue: author - stringValue: ' - ' - fieldReferenceValue: title - name: str_concat + name: string_concat name: select - args: - integerValue: '1' @@ -825,7 +845,7 @@ tests: pipeline: - Collection: books - Where: - - StartsWith: + - Expr.starts_with: - Field: title - Constant: The - Select: @@ -870,7 +890,7 @@ tests: pipeline: - Collection: books - Where: - - EndsWith: + - Expr.ends_with: - Field: title - Constant: y - Select: @@ -913,13 +933,13 @@ tests: pipeline: - Collection: books - Select: - - ExprWithAlias: - - CharLength: + - AliasedExpr: + - Expr.char_length: - Field: title - "titleLength" - title - Where: - - Gt: + - Expr.greater_than: - Field: titleLength - Constant: 20 - Sort: @@ -957,7 +977,7 @@ tests: args: - fieldReferenceValue: titleLength - integerValue: '20' - name: gt + name: greater_than name: where - args: - mapValue: @@ -971,12 +991,12 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: "Douglas Adams" - Select: - - ExprWithAlias: - - CharLength: + - AliasedExpr: + - Expr.char_length: - Field: title - "title_length" assert_results: @@ -992,7 +1012,7 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: @@ -1007,13 +1027,13 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: Douglas Adams - Select: - - ExprWithAlias: - - ByteLength: - - StrConcat: + - AliasedExpr: + - Expr.byte_length: + - Expr.string_concat: - Field: title - Constant: _银河系漫游指南 - "title_byte_length" @@ -1030,7 +1050,7 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: @@ -1042,14 +1062,14 @@ tests: args: - fieldReferenceValue: title - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" - name: str_concat + name: string_concat name: byte_length name: select - 
description: testLike pipeline: - Collection: books - Where: - - Like: + - Expr.like: - Field: title - Constant: "%Guide%" - Select: @@ -1061,7 +1081,7 @@ tests: pipeline: - Collection: books - Where: - - RegexContains: + - Expr.regex_contains: - Field: title - Constant: "(?i)(the|of)" assert_count: 5 @@ -1083,7 +1103,7 @@ tests: pipeline: - Collection: books - Where: - - RegexMatch: + - Expr.regex_match: - Field: title - Constant: ".*(?i)(the|of).*" assert_count: 5 @@ -1104,42 +1124,42 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - - ExprWithAlias: - - Add: + - AliasedExpr: + - Expr.add: - Field: rating - Constant: 1 - "ratingPlusOne" - - ExprWithAlias: - - Subtract: + - AliasedExpr: + - Expr.subtract: - Field: published - Constant: 1900 - "yearsSince1900" - - ExprWithAlias: - - Multiply: + - AliasedExpr: + - Expr.multiply: - Field: rating - Constant: 10 - "ratingTimesTen" - - ExprWithAlias: - - Divide: + - AliasedExpr: + - Expr.divide: - Field: rating - Constant: 2 - "ratingDividedByTwo" - - ExprWithAlias: - - Multiply: + - AliasedExpr: + - Expr.multiply: - Field: rating - Constant: 20 - "ratingTimes20" - - ExprWithAlias: - - Add: + - AliasedExpr: + - Expr.add: - Field: rating - Constant: 3 - "ratingPlus3" - - ExprWithAlias: - - Mod: + - AliasedExpr: + - Expr.mod: - Field: rating - Constant: 2 - "ratingMod2" @@ -1162,7 +1182,7 @@ tests: args: - fieldReferenceValue: title - stringValue: To Kill a Mockingbird - name: eq + name: equal name: where - args: - mapValue: @@ -1215,13 +1235,13 @@ tests: - Collection: books - Where: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.2 - - Lte: + - Expr.less_than_or_equal: - Field: rating - Constant: 4.5 - - Neq: + - Expr.not_equal: - Field: genre - Constant: Science Fiction - Select: @@ -1251,17 +1271,17 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.2 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: lte + name: less_than_or_equal - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: neq + name: not_equal name: and name: where - args: @@ -1286,13 +1306,13 @@ tests: - Where: - Or: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.5 - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction - - Lt: + - Expr.less_than: - Field: published - Constant: 1900 - Select: @@ -1320,18 +1340,18 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: and - functionValue: args: - fieldReferenceValue: published - integerValue: '1900' - name: lt + name: less_than name: or name: where - args: @@ -1353,12 +1373,12 @@ tests: - Collection: books - Where: - Not: - - IsNaN: + - Expr.is_nan: - Field: rating - Select: - - ExprWithAlias: + - AliasedExpr: - Not: - - IsNaN: + - Expr.is_nan: - Field: rating - "ratingIsNotNaN" - Limit: 1 @@ -1398,23 +1418,23 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: Douglas Adams - Select: - - ExprWithAlias: - - LogicalMax: + - AliasedExpr: + - Expr.logical_maximum: - Field: rating - Constant: 4.5 - "max_rating" - - ExprWithAlias: - - LogicalMax: + - AliasedExpr: + - Expr.logical_minimum: - Field: published - Constant: 1900 - - "max_published" + - "min_published" assert_results: - max_rating: 
4.5 - max_published: 1979 + min_published: 1900 assert_proto: pipeline: stages: @@ -1426,23 +1446,23 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: fields: - max_published: + min_published: functionValue: args: - fieldReferenceValue: published - integerValue: '1900' - name: logical_maximum + name: minimum max_rating: functionValue: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: logical_maximum + name: maximum name: select - description: testMapGet pipeline: @@ -1452,14 +1472,14 @@ tests: - Field: published - DESCENDING - Select: - - ExprWithAlias: - - MapGet: + - AliasedExpr: + - Expr.map_get: - Field: awards - - Constant: hugo + - hugo - "hugoAward" - Field: title - Where: - - Eq: + - Expr.equal: - Field: hugoAward - Constant: true assert_results: @@ -1498,13 +1518,13 @@ tests: args: - fieldReferenceValue: hugoAward - booleanValue: true - name: eq + name: equal name: where - description: testNestedFields pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: awards.hugo - Constant: true - Sort: @@ -1530,7 +1550,7 @@ tests: args: - fieldReferenceValue: awards.hugo - booleanValue: true - name: eq + name: equal name: where - args: - mapValue: @@ -1604,7 +1624,7 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: @@ -1626,7 +1646,7 @@ tests: args: - fieldReferenceValue: title - stringValue: The Hitchhiker's Guide to the Galaxy - name: eq + name: equal name: where - args: - fieldReferenceValue: tags @@ -1638,3 +1658,301 @@ tests: tags_alias: fieldReferenceValue: tags_alias name: select + - description: testGreaterThanOrEqual + pipeline: + - Collection: books + - Where: + - Expr.greater_than_or_equal: + - Field: rating + - Constant: 4.6 + - Select: + - title + - rating + - Sort: + - Ordering: + - Field: rating + - ASCENDING + assert_results: + - title: Dune + rating: 4.6 + - title: The Lord of the Rings + rating: 4.7 + - description: testInAndNotIn + pipeline: + - Collection: books + - Where: + - And: + - Expr.equal_any: + - Field: genre + - - Constant: Romance + - Constant: Dystopian + - Expr.not_equal_any: + - Field: author + - - Constant: "George Orwell" + assert_results: + - title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + "arthur c. 
clarke": true + "booker prize": false + - description: testArrayReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.array_reverse: + - Field: tags + - "reversedTags" + assert_results: + - reversedTags: + - adventure + - space + - comedy + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - Expr.exists: + - Field: awards.pulitzer + - Expr.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testSum + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpr: + - Expr.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + - description: testStringContains + pipeline: + - Collection: books + - Where: + - Expr.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpr: + - Expr.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 4 + - description: testTimestampFunctions + pipeline: + - Collection: timestamps + - Select: + - AliasedExpr: + - Expr.timestamp_to_unix_micros: + - Field: time + - "micros" + - AliasedExpr: + - Expr.timestamp_to_unix_millis: + - Field: time + - "millis" + - AliasedExpr: + - Expr.timestamp_to_unix_seconds: + - Field: time + - "seconds" + - AliasedExpr: + - Expr.unix_micros_to_timestamp: + - Field: micros + - "from_micros" + - AliasedExpr: + - Expr.unix_millis_to_timestamp: + - Field: millis + - "from_millis" + - AliasedExpr: + - Expr.unix_seconds_to_timestamp: + - Field: seconds + - "from_seconds" + - AliasedExpr: + - Expr.timestamp_add: + - Field: time + - Constant: "day" + - Constant: 1 + - "plus_day" + - AliasedExpr: + - Expr.timestamp_subtract: + - Field: time + - Constant: "hour" + - Constant: 1 + - "minus_hour" + assert_results: + - micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + from_micros: "1993-04-28T12:01:00.654321+00:00" + from_millis: "1993-04-28T12:01:00.654000+00:00" + from_seconds: "1993-04-28T12:01:00.000000+00:00" + plus_day: "1993-04-29T12:01:00.654321+00:00" + minus_hour: "1993-04-28T11:01:00.654321+00:00" + - description: testCollectionId + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.collection_id: + - Field: __name__ + - "collectionName" + assert_results: + - collectionName: "books" + - description: testXor + pipeline: + - Collection: books + - Where: + - Xor: + - - Expr.equal: + - Field: genre + - Constant: Romance + - Expr.greater_than: + - Field: published + - Constant: 1980 + - Select: + - title + - genre + - published + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + genre: "Romance" + published: 1813 + - title: "The Handmaid's Tale" + genre: "Dystopian" + published: 1985 + - description: testConditional + pipeline: + - Collection: books + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Constant: "Modern" + - Constant: "Classic" + - "era" + - Sort: + - Ordering: + - Field: title + - 
ASCENDING + - Limit: 4 + assert_results: + - title: "1984" + era: "Classic" + - title: "Crime and Punishment" + era: "Classic" + - title: "Dune" + era: "Modern" + - title: "One Hundred Years of Solitude" + era: "Modern" + - description: testFieldToFieldArithmetic + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.add: + - Field: published + - Field: rating + - "pub_plus_rating" + assert_results: + - pub_plus_rating: 1969.6 + - description: testFieldToFieldComparison + pipeline: + - Collection: books + - Where: + - Expr.greater_than: + - Field: published + - Field: rating + - Select: + - title + assert_count: 10 # All books were published after year 4.7 + - description: testExistsNegative + pipeline: + - Collection: books + - Where: + - Expr.exists: + - Field: non_existent_field + assert_count: 0 + - description: testConditionalWithFields + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "Dune" + - Constant: "1984" + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Field: author + - Field: genre + - "conditional_field" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + conditional_field: "Dystopian" + - title: "Dune" + conditional_field: "Frank Herbert" diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 9d44bbc57..d4c654e63 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +import datetime import pytest import yaml import re @@ -26,6 +27,7 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore_v1.vector import Vector from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client, AsyncClient @@ -91,7 +93,7 @@ def test_pipeline_results(test_dict, client): """ Ensure pipeline returns expected results """ - expected_results = test_dict.get("assert_results", None) + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected @@ -132,7 +134,7 @@ async def test_pipeline_results_async(test_dict, async_client): """ Ensure pipeline returns expected results """ - expected_results = test_dict.get("assert_results", None) + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected @@ -160,7 +162,7 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]): # find arguments if given if isinstance(stage, dict): stage_yaml_args = stage[stage_name] - stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) + stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args) else: # yaml has no arguments stage_obj = stage_cls() @@ -178,15 +180,21 @@ def _parse_expressions(client, yaml_element: Any): if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): # build pipeline expressions if possible cls_str = next(iter(yaml_element)) - cls = getattr(pipeline_expressions, cls_str) + callable_obj = None 
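+            # Dotted names such as "Expr.equal" resolve to a method on a
+            # pipeline_expressions class; bare names fall back to a
+            # module-level attribute below.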
+ if "." in cls_str: + cls_name, method_name = cls_str.split(".") + cls = getattr(pipeline_expressions, cls_name) + callable_obj = getattr(cls, method_name) + else: + callable_obj = getattr(pipeline_expressions, cls_str) yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, client, yaml_args) + return _apply_yaml_args_to_callable(callable_obj, client, yaml_args) elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): # build pipeline stage if possible (eg, for SampleOptions) cls_str = next(iter(yaml_element)) cls = getattr(stages, cls_str) yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, client, yaml_args) + return _apply_yaml_args_to_callable(cls, client, yaml_args) elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": # find Pipeline objects for Union expressions other_ppl = yaml_element["Pipeline"] @@ -203,25 +211,33 @@ def _parse_expressions(client, yaml_element: Any): return yaml_element -def _apply_yaml_args(cls, client, yaml_args): +def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ Helper to instantiate a class with yaml arguments. The arguments will be applied as positional or keyword arguments, based on type """ if isinstance(yaml_args, dict): - return cls(**_parse_expressions(client, yaml_args)) + return callable_obj(**_parse_expressions(client, yaml_args)) elif isinstance(yaml_args, list): # yaml has an array of arguments. Treat as args - return cls(*_parse_expressions(client, yaml_args)) + return callable_obj(*_parse_expressions(client, yaml_args)) else: # yaml has a single argument - return cls(_parse_expressions(client, yaml_args)) + return callable_obj(_parse_expressions(client, yaml_args)) def _is_expr_string(yaml_str): """ Returns true if a string represents a class in pipeline_expressions """ + if isinstance(yaml_str, str) and "." 
in yaml_str: + parts = yaml_str.split(".") + if len(parts) == 2: + cls_name, method_name = parts + if hasattr(pipeline_expressions, cls_name): + cls = getattr(pipeline_expressions, cls_name) + if hasattr(cls, method_name): + return True return ( isinstance(yaml_str, str) and yaml_str[0].isupper() @@ -251,6 +267,26 @@ def event_loop(): loop.close() +def _parse_yaml_types(data): + """helper to convert yaml data to firestore objects when needed""" + if isinstance(data, dict): + return {key: _parse_yaml_types(value) for key, value in data.items()} + if isinstance(data, list): + # detect vectors + if all([isinstance(d, float) for d in data]): + return Vector(data) + else: + return [_parse_yaml_types(value) for value in data] + # detect timestamps + if isinstance(data, str) and ":" in data: + try: + parsed_datetime = datetime.datetime.fromisoformat(data) + return parsed_datetime + except ValueError: + pass + return data + + @pytest.fixture(scope="module") def client(): """ @@ -258,6 +294,7 @@ def client(): """ client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) data = yaml_loader("data") + to_delete = [] try: # setup data batch = client.batch() @@ -265,16 +302,14 @@ def client(): collection_ref = client.collection(collection_name) for document_id, document_data in documents.items(): document_ref = collection_ref.document(document_id) - batch.set(document_ref, document_data) + to_delete.append(document_ref) + batch.set(document_ref, _parse_yaml_types(document_data)) batch.commit() yield client finally: # clear data - for collection_name, documents in data.items(): - collection_ref = client.collection(collection_name) - for document_id in documents: - document_ref = collection_ref.document(document_id) - document_ref.delete() + for document_ref in to_delete: + document_ref.delete() @pytest.fixture(scope="module") diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 47eedc983..b3ed83337 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -17,7 +17,6 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field -from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_async_pipeline(*args, client=mock.Mock()): @@ -386,7 +385,7 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): ("remove_fields", (Field.of("n"),), stages.RemoveFields), ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), - ("where", (Exists(Field.of("n")),), stages.Where), + ("where", (Field.of("n").exists(),), stages.Where), ("find_nearest", ("name", [0.1], 0), stages.FindNearest), ( "find_nearest", diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b237ad5ac..f90279e00 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -17,7 +17,6 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field -from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_pipeline(*args, client=mock.Mock()): @@ -363,7 +362,7 @@ def test_pipeline_execute_stream_equivalence_mocked(): ("remove_fields", (Field.of("n"),), stages.RemoveFields), ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), - ("where", (Exists(Field.of("n")),), stages.Where), + ("where", (Field.of("n").exists(),), stages.Where), ("find_nearest", ("name", [0.1], 
0), stages.FindNearest), ( "find_nearest", diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 936c0a0a9..c5329df33 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -22,8 +22,13 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint -from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs import google.cloud.firestore_v1.pipeline_expressions as expr +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr +from google.cloud.firestore_v1.pipeline_expressions import _ListOfExprs +from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Ordering @pytest.fixture @@ -37,126 +42,43 @@ class TestOrdering: @pytest.mark.parametrize( "direction_arg,expected_direction", [ - ("ASCENDING", expr.Ordering.Direction.ASCENDING), - ("DESCENDING", expr.Ordering.Direction.DESCENDING), - ("ascending", expr.Ordering.Direction.ASCENDING), - ("descending", expr.Ordering.Direction.DESCENDING), - (expr.Ordering.Direction.ASCENDING, expr.Ordering.Direction.ASCENDING), - (expr.Ordering.Direction.DESCENDING, expr.Ordering.Direction.DESCENDING), + ("ASCENDING", Ordering.Direction.ASCENDING), + ("DESCENDING", Ordering.Direction.DESCENDING), + ("ascending", Ordering.Direction.ASCENDING), + ("descending", Ordering.Direction.DESCENDING), + (Ordering.Direction.ASCENDING, Ordering.Direction.ASCENDING), + (Ordering.Direction.DESCENDING, Ordering.Direction.DESCENDING), ], ) def test_ctor(self, direction_arg, expected_direction): - instance = expr.Ordering("field1", direction_arg) - assert isinstance(instance.expr, expr.Field) + instance = Ordering("field1", direction_arg) + assert isinstance(instance.expr, Field) assert instance.expr.path == "field1" assert instance.order_dir == expected_direction def test_repr(self): - field_expr = expr.Field.of("field1") - instance = expr.Ordering(field_expr, "ASCENDING") + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") repr_str = repr(instance) assert repr_str == "Field.of('field1').ascending()" - instance = expr.Ordering(field_expr, "DESCENDING") + instance = Ordering(field_expr, "DESCENDING") repr_str = repr(instance) assert repr_str == "Field.of('field1').descending()" def test_to_pb(self): - field_expr = expr.Field.of("field1") - instance = expr.Ordering(field_expr, "ASCENDING") + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") result = instance._to_pb() assert result.map_value.fields["expression"].field_reference_value == "field1" assert result.map_value.fields["direction"].string_value == "ascending" - instance = expr.Ordering(field_expr, "DESCENDING") + instance = Ordering(field_expr, "DESCENDING") result = instance._to_pb() assert result.map_value.fields["expression"].field_reference_value == "field1" assert result.map_value.fields["direction"].string_value == "descending" -class TestExpr: - def test_ctor(self): - """ - Base class should be abstract - """ - with pytest.raises(TypeError): - expr.Expr() - - @pytest.mark.parametrize( - "method,args,result_cls", - [ - ("add", (2,), expr.Add), - ("subtract", (2,), expr.Subtract), - ("multiply", (2,), 
expr.Multiply), - ("divide", (2,), expr.Divide), - ("mod", (2,), expr.Mod), - ("logical_max", (2,), expr.LogicalMax), - ("logical_min", (2,), expr.LogicalMin), - ("eq", (2,), expr.Eq), - ("neq", (2,), expr.Neq), - ("lt", (2,), expr.Lt), - ("lte", (2,), expr.Lte), - ("gt", (2,), expr.Gt), - ("gte", (2,), expr.Gte), - ("in_any", ([None],), expr.In), - ("not_in_any", ([None],), expr.Not), - ("array_contains", (None,), expr.ArrayContains), - ("array_contains_all", ([None],), expr.ArrayContainsAll), - ("array_contains_any", ([None],), expr.ArrayContainsAny), - ("array_length", (), expr.ArrayLength), - ("array_reverse", (), expr.ArrayReverse), - ("is_nan", (), expr.IsNaN), - ("exists", (), expr.Exists), - ("sum", (), expr.Sum), - ("avg", (), expr.Avg), - ("count", (), expr.Count), - ("min", (), expr.Min), - ("max", (), expr.Max), - ("char_length", (), expr.CharLength), - ("byte_length", (), expr.ByteLength), - ("like", ("pattern",), expr.Like), - ("regex_contains", ("regex",), expr.RegexContains), - ("regex_matches", ("regex",), expr.RegexMatch), - ("str_contains", ("substring",), expr.StrContains), - ("starts_with", ("prefix",), expr.StartsWith), - ("ends_with", ("postfix",), expr.EndsWith), - ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), - ("map_get", ("key",), expr.MapGet), - ("vector_length", (), expr.VectorLength), - ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), - ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), - ("timestamp_to_unix_millis", (), expr.TimestampToUnixMillis), - ("unix_millis_to_timestamp", (), expr.UnixMillisToTimestamp), - ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), - ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), - ("timestamp_add", ("day", 1), expr.TimestampAdd), - ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), - ("ascending", (), expr.Ordering), - ("descending", (), expr.Ordering), - ("as_", ("alias",), expr.ExprWithAlias), - ], - ) - @pytest.mark.parametrize( - "base_instance", - [ - expr.Constant(1), - expr.Function.add("1", 1), - expr.Field.of("test"), - expr.Constant(1).as_("one"), - ], - ) - def test_infix_call(self, method, args, result_cls, base_instance): - """ - many FilterCondition expressions support infix execution, and are exposed as methods on Expr. 
Test calling them - """ - method_ptr = getattr(base_instance, method) - - result = method_ptr(*args) - assert isinstance(result, result_cls) - if isinstance(result, expr.Function) and not method == "not_in_any": - assert result.params[0] == base_instance - - class TestConstant: @pytest.mark.parametrize( "input_val, to_pb_val", @@ -200,7 +122,7 @@ class TestConstant: ], ) def test_to_pb(self, input_val, to_pb_val): - instance = expr.Constant.of(input_val) + instance = Constant.of(input_val) assert instance._to_pb() == to_pb_val @pytest.mark.parametrize( @@ -226,25 +148,25 @@ def test_to_pb(self, input_val, to_pb_val): ], ) def test_repr(self, input_val, expected): - instance = expr.Constant.of(input_val) + instance = Constant.of(input_val) repr_string = repr(instance) assert repr_string == expected @pytest.mark.parametrize( "first,second,expected", [ - (expr.Constant.of(1), expr.Constant.of(2), False), - (expr.Constant.of(1), expr.Constant.of(1), True), - (expr.Constant.of(1), 1, True), - (expr.Constant.of(1), 2, False), - (expr.Constant.of("1"), 1, False), - (expr.Constant.of("1"), "1", True), - (expr.Constant.of(None), expr.Constant.of(0), False), - (expr.Constant.of(None), expr.Constant.of(None), True), - (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2, 3]), True), - (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2]), False), - (expr.Constant.of([1, 2, 3]), [1, 2, 3], True), - (expr.Constant.of([1, 2, 3]), object(), False), + (Constant.of(1), Constant.of(2), False), + (Constant.of(1), Constant.of(1), True), + (Constant.of(1), 1, True), + (Constant.of(1), 2, False), + (Constant.of("1"), 1, False), + (Constant.of("1"), "1", True), + (Constant.of(None), Constant.of(0), False), + (Constant.of(None), Constant.of(None), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2, 3]), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2]), False), + (Constant.of([1, 2, 3]), [1, 2, 3], True), + (Constant.of([1, 2, 3]), object(), False), ], ) def test_equality(self, first, second, expected): @@ -253,49 +175,49 @@ def test_equality(self, first, second, expected): class TestListOfExprs: def test_to_pb(self): - instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + instance = _ListOfExprs([Constant(1), Constant(2)]) result = instance._to_pb() assert len(result.array_value.values) == 2 assert result.array_value.values[0].integer_value == 1 assert result.array_value.values[1].integer_value == 2 def test_empty_to_pb(self): - instance = expr.ListOfExprs([]) + instance = _ListOfExprs([]) result = instance._to_pb() assert len(result.array_value.values) == 0 def test_repr(self): - instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + instance = _ListOfExprs([Constant(1), Constant(2)]) repr_string = repr(instance) - assert repr_string == "ListOfExprs([Constant.of(1), Constant.of(2)])" - empty_instance = expr.ListOfExprs([]) + assert repr_string == "[Constant.of(1), Constant.of(2)]" + empty_instance = _ListOfExprs([]) empty_repr_string = repr(empty_instance) - assert empty_repr_string == "ListOfExprs([])" + assert empty_repr_string == "[]" @pytest.mark.parametrize( "first,second,expected", [ - (expr.ListOfExprs([]), expr.ListOfExprs([]), True), - (expr.ListOfExprs([]), expr.ListOfExprs([expr.Constant(1)]), False), - (expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([]), False), + (_ListOfExprs([]), _ListOfExprs([]), True), + (_ListOfExprs([]), _ListOfExprs([Constant(1)]), False), + (_ListOfExprs([Constant(1)]), _ListOfExprs([]), False), ( - 
expr.ListOfExprs([expr.Constant(1)]), - expr.ListOfExprs([expr.Constant(1)]), + _ListOfExprs([Constant(1)]), + _ListOfExprs([Constant(1)]), True, ), ( - expr.ListOfExprs([expr.Constant(1)]), - expr.ListOfExprs([expr.Constant(2)]), + _ListOfExprs([Constant(1)]), + _ListOfExprs([Constant(2)]), False, ), ( - expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), - expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + _ListOfExprs([Constant(1), Constant(2)]), + _ListOfExprs([Constant(1), Constant(2)]), True, ), - (expr.ListOfExprs([expr.Constant(1)]), [expr.Constant(1)], False), - (expr.ListOfExprs([expr.Constant(1)]), [1], False), - (expr.ListOfExprs([expr.Constant(1)]), object(), False), + (_ListOfExprs([Constant(1)]), [Constant(1)], False), + (_ListOfExprs([Constant(1)]), [1], False), + (_ListOfExprs([Constant(1)]), object(), False), ], ) def test_equality(self, first, second, expected): @@ -316,8 +238,8 @@ def test_ctor(self): def test_value_from_selectables(self): selectable_list = [ - expr.Field.of("field1"), - expr.Field.of("field2").as_("alias2"), + Field.of("field1"), + Field.of("field2").as_("alias2"), ] result = expr.Selectable._value_from_selectables(*selectable_list) assert len(result.map_value.fields) == 2 @@ -327,14 +249,14 @@ def test_value_from_selectables(self): @pytest.mark.parametrize( "first,second,expected", [ - (expr.Field.of("field1"), expr.Field.of("field1"), True), - (expr.Field.of("field1"), expr.Field.of("field2"), False), - (expr.Field.of(None), object(), False), - (expr.Field.of("f").as_("a"), expr.Field.of("f").as_("a"), True), - (expr.Field.of("one").as_("a"), expr.Field.of("two").as_("a"), False), - (expr.Field.of("f").as_("one"), expr.Field.of("f").as_("two"), False), - (expr.Field.of("field"), expr.Field.of("field").as_("alias"), False), - (expr.Field.of("field").as_("alias"), expr.Field.of("field"), False), + (Field.of("field1"), Field.of("field1"), True), + (Field.of("field1"), Field.of("field2"), False), + (Field.of(None), object(), False), + (Field.of("f").as_("a"), Field.of("f").as_("a"), True), + (Field.of("one").as_("a"), Field.of("two").as_("a"), False), + (Field.of("f").as_("one"), Field.of("f").as_("two"), False), + (Field.of("field"), Field.of("field").as_("alias"), False), + (Field.of("field").as_("alias"), Field.of("field"), False), ], ) def test_equality(self, first, second, expected): @@ -342,52 +264,79 @@ def test_equality(self, first, second, expected): class TestField: def test_repr(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") repr_string = repr(instance) assert repr_string == "Field.of('field1')" def test_of(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") assert instance.path == "field1" def test_to_pb(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") result = instance._to_pb() assert result.field_reference_value == "field1" def test_to_map(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") result = instance._to_map() assert result[0] == "field1" assert result[1] == Value(field_reference_value="field1") - class TestExprWithAlias: + class TestAliasedExpr: def test_repr(self): - instance = expr.Field.of("field1").as_("alias1") + instance = Field.of("field1").as_("alias1") assert repr(instance) == "Field.of('field1').as_('alias1')" def test_ctor(self): - arg = expr.Field.of("field1") + arg = Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) assert 
instance.expr == arg assert instance.alias == alias def test_to_pb(self): - arg = expr.Field.of("field1") + arg = Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) result = instance._to_pb() assert result.map_value.fields.get("alias1") == arg._to_pb() def test_to_map(self): - instance = expr.Field.of("field1").as_("alias1") + instance = Field.of("field1").as_("alias1") result = instance._to_map() assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") + class TestAliasedAggregate: + def test_repr(self): + instance = Field.of("field1").maximum().as_("alias1") + assert repr(instance) == "Field.of('field1').maximum().as_('alias1')" + + def test_ctor(self): + arg = Expr.minimum("field1") + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = Field.of("field1").average() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + arg = Field.of("field1").count() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == arg._to_pb() + -class TestFilterCondition: +class TestBooleanExpr: def test__from_query_filter_pb_composite_filter_or(self, mock_client): """ test composite OR filters @@ -415,17 +364,13 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Eq(expr.Field.of("field2"), expr.Constant(None)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) + expected_cond2 = expr.And(field2.exists(), field2.equal(Constant(None))) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -458,17 +403,13 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Gt(expr.Field.of("field1"), expr.Constant(100)), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Lt(expr.Field.of("field2"), expr.Constant(200)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + expected_cond1 = expr.And(field1.exists(), field1.greater_than(Constant(100))) + expected_cond2 = expr.And(field2.exists(), field2.less_than(Constant(200))) expected = expr.And(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -509,19 +450,15 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): composite_filter=outer_or_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = 
BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Gt(expr.Field.of("field2"), expr.Constant(10)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + field3 = Field.of("field3") + expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) + expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) expected_cond3 = expr.And( - expr.Exists(expr.Field.of("field3")), - expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + field3.exists(), expr.Not(field3.equal(Constant(None))) ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -546,23 +483,23 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ) with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, expected_expr_func", [ - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN), + (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, lambda f: expr.Not(f.is_nan()), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.eq(None), + lambda f: f.equal(None), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.eq(None)), + lambda f: expr.Not(f.equal(None)), ), ], ) @@ -579,12 +516,12 @@ def test__from_query_filter_pb_unary_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - field_expr_inst = expr.Field.of(field_path) + field_expr_inst = Field.of(field_path) expected_condition = expected_expr_func(field_expr_inst) # should include existance checks - expected = expr.And(expr.Exists(field_expr_inst), expected_condition) + expected = expr.And(field_expr_inst.exists(), expected_condition) assert repr(result) == repr(expected) @@ -600,40 +537,56 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, value, expected_expr_func", [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + 10, + Expr.less_than, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, - expr.Lte, + Expr.less_than_or_equal, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + 10, + Expr.greater_than, ), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, - expr.Gte, + Expr.greater_than_or_equal, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expr.equal), + ( + 
query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, + 10, + Expr.not_equal, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, - expr.ArrayContains, + Expr.array_contains, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], - expr.ArrayContainsAny, + Expr.array_contains_any, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.IN, + [10, 20], + Expr.equal_any, ), - (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], - lambda f, v: expr.Not(f.in_any(v)), + Expr.not_equal_any, ), ], ) @@ -652,18 +605,16 @@ def test__from_query_filter_pb_field_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - field_expr = expr.Field.of(field_path) + field_expr = Field.of(field_path) # convert values into constants value = ( - [expr.Constant(e) for e in value] - if isinstance(value, list) - else expr.Constant(value) + [Constant(e) for e in value] if isinstance(value, list) else Constant(value) ) expected_condition = expected_expr_func(field_expr, value) # should include existance checks - expected = expr.And(expr.Exists(field_expr), expected_condition) + expected = expr.And(field_expr.exists(), expected_condition) assert repr(result) == repr(expected) @@ -681,7 +632,7 @@ def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ @@ -689,26 +640,64 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ # Test with an unexpected protobuf type with pytest.raises(TypeError, match="Unexpected filter type"): - FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) -class TestFilterConditionClasses: +class TestExpressionMethods: """ - contains test methods for each Expr class that derives from FilterCondition + contains test methods for each Expr method """ + @pytest.mark.parametrize( + "first,second,expected", + [ + ( + Field.of("a").char_length(), + Field.of("a").char_length(), + True, + ), + ( + Field.of("a").char_length(), + Field.of("b").char_length(), + False, + ), + ( + Field.of("a").char_length(), + Field.of("a").byte_length(), + False, + ), + ( + Field.of("a").char_length(), + Field.of("b").byte_length(), + False, + ), + ( + Constant.of("").byte_length(), + Field.of("").byte_length(), + False, + ), + (Field.of("").byte_length(), Field.of("").byte_length(), True), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + def _make_arg(self, name="Mock"): - arg = mock.Mock() - arg.__repr__ = lambda x: name + class MockExpr(Constant): + def __repr__(self): + return self.value + + arg = MockExpr(name) return arg def test_and(self): arg1 = self._make_arg() arg2 = self._make_arg() - instance = expr.And(arg1, 
arg2) + arg3 = self._make_arg() + instance = expr.And(arg1, arg2, arg3) assert instance.name == "and" - assert instance.params == [arg1, arg2] - assert repr(instance) == "And(Mock, Mock)" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "And(Mock, Mock, Mock)" def test_or(self): arg1 = self._make_arg("Arg1") @@ -716,102 +705,134 @@ def test_or(self): instance = expr.Or(arg1, arg2) assert instance.name == "or" assert instance.params == [arg1, arg2] - assert repr(instance) == "Arg1.or(Arg2)" + assert repr(instance) == "Or(Arg1, Arg2)" def test_array_contains(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element") - instance = expr.ArrayContains(arg1, arg2) + instance = Expr.array_contains(arg1, arg2) assert instance.name == "array_contains" assert instance.params == [arg1, arg2] assert repr(instance) == "ArrayField.array_contains(Element)" + infix_instance = arg1.array_contains(arg2) + assert infix_instance == instance def test_array_contains_any(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = self._make_arg("Element2") - instance = expr.ArrayContainsAny(arg1, [arg2, arg3]) + instance = Expr.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" - assert isinstance(instance.params[1], ListOfExprs) + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert ( - repr(instance) - == "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))" - ) + assert repr(instance) == "ArrayField.array_contains_any([Element1, Element2])" + infix_instance = arg1.array_contains_any([arg2, arg3]) + assert infix_instance == instance def test_exists(self): arg1 = self._make_arg("Field") - instance = expr.Exists(arg1) + instance = Expr.exists(arg1) assert instance.name == "exists" assert instance.params == [arg1] assert repr(instance) == "Field.exists()" + infix_instance = arg1.exists() + assert infix_instance == instance - def test_eq(self): + def test_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Eq(arg1, arg2) - assert instance.name == "eq" + instance = Expr.equal(arg1, arg2) + assert instance.name == "equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.eq(Right)" + assert repr(instance) == "Left.equal(Right)" + infix_instance = arg1.equal(arg2) + assert infix_instance == instance - def test_gte(self): + def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gte(arg1, arg2) - assert instance.name == "gte" + instance = Expr.greater_than_or_equal(arg1, arg2) + assert instance.name == "greater_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gte(Right)" + assert repr(instance) == "Left.greater_than_or_equal(Right)" + infix_instance = arg1.greater_than_or_equal(arg2) + assert infix_instance == instance - def test_gt(self): + def test_greater_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gt(arg1, arg2) - assert instance.name == "gt" + instance = Expr.greater_than(arg1, arg2) + assert instance.name == "greater_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gt(Right)" + assert repr(instance) == "Left.greater_than(Right)" + infix_instance = arg1.greater_than(arg2) + assert infix_instance == instance - def test_lte(self): + def test_less_than_or_equal(self): arg1 = 
self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lte(arg1, arg2) - assert instance.name == "lte" + instance = Expr.less_than_or_equal(arg1, arg2) + assert instance.name == "less_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lte(Right)" + assert repr(instance) == "Left.less_than_or_equal(Right)" + infix_instance = arg1.less_than_or_equal(arg2) + assert infix_instance == instance - def test_lt(self): + def test_less_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lt(arg1, arg2) - assert instance.name == "lt" + instance = Expr.less_than(arg1, arg2) + assert instance.name == "less_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lt(Right)" + assert repr(instance) == "Left.less_than(Right)" + infix_instance = arg1.less_than(arg2) + assert infix_instance == instance - def test_neq(self): + def test_not_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Neq(arg1, arg2) - assert instance.name == "neq" + instance = Expr.not_equal(arg1, arg2) + assert instance.name == "not_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.neq(Right)" + assert repr(instance) == "Left.not_equal(Right)" + infix_instance = arg1.not_equal(arg2) + assert infix_instance == instance + + def test_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = Expr.equal_any(arg1, [arg2, arg3]) + assert instance.name == "equal_any" + assert isinstance(instance.params[1], _ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.equal_any([Value1, Value2])" + infix_instance = arg1.equal_any([arg2, arg3]) + assert infix_instance == instance - def test_in(self): + def test_not_equal_any(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = expr.In(arg1, [arg2, arg3]) - assert instance.name == "in" - assert isinstance(instance.params[1], ListOfExprs) + instance = Expr.not_equal_any(arg1, [arg2, arg3]) + assert instance.name == "not_equal_any" + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + assert repr(instance) == "Field.not_equal_any([Value1, Value2])" + infix_instance = arg1.not_equal_any([arg2, arg3]) + assert infix_instance == instance def test_is_nan(self): arg1 = self._make_arg("Value") - instance = expr.IsNaN(arg1) + instance = Expr.is_nan(arg1) assert instance.name == "is_nan" assert instance.params == [arg1] assert repr(instance) == "Value.is_nan()" + infix_instance = arg1.is_nan() + assert infix_instance == instance def test_not(self): arg1 = self._make_arg("Condition") @@ -824,72 +845,83 @@ def test_array_contains_all(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = self._make_arg("Element2") - instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) + instance = Expr.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], ListOfExprs) + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert ( - repr(instance) - == "ArrayField.array_contains_all(ListOfExprs([Element1, 
Element2]))" - ) + assert repr(instance) == "ArrayField.array_contains_all([Element1, Element2])" + infix_instance = arg1.array_contains_all([arg2, arg3]) + assert infix_instance == instance def test_ends_with(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Postfix") - instance = expr.EndsWith(arg1, arg2) + instance = Expr.ends_with(arg1, arg2) assert instance.name == "ends_with" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.ends_with(Postfix)" + infix_instance = arg1.ends_with(arg2) + assert infix_instance == instance - def test_if(self): + def test_conditional(self): arg1 = self._make_arg("Condition") - arg2 = self._make_arg("TrueExpr") - arg3 = self._make_arg("FalseExpr") - instance = expr.If(arg1, arg2, arg3) - assert instance.name == "if" + arg2 = self._make_arg("ThenExpr") + arg3 = self._make_arg("ElseExpr") + instance = expr.Conditional(arg1, arg2, arg3) + assert instance.name == "conditional" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + assert repr(instance) == "Conditional(Condition, ThenExpr, ElseExpr)" def test_like(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Pattern") - instance = expr.Like(arg1, arg2) + instance = Expr.like(arg1, arg2) assert instance.name == "like" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.like(Pattern)" + infix_instance = arg1.like(arg2) + assert infix_instance == instance def test_regex_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Regex") - instance = expr.RegexContains(arg1, arg2) + instance = Expr.regex_contains(arg1, arg2) assert instance.name == "regex_contains" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.regex_contains(Regex)" + infix_instance = arg1.regex_contains(arg2) + assert infix_instance == instance def test_regex_match(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Regex") - instance = expr.RegexMatch(arg1, arg2) + instance = Expr.regex_match(arg1, arg2) assert instance.name == "regex_match" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.regex_match(Regex)" + infix_instance = arg1.regex_match(arg2) + assert infix_instance == instance def test_starts_with(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Prefix") - instance = expr.StartsWith(arg1, arg2) + instance = Expr.starts_with(arg1, arg2) assert instance.name == "starts_with" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.starts_with(Prefix)" + infix_instance = arg1.starts_with(arg2) + assert infix_instance == instance - def test_str_contains(self): + def test_string_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Substring") - instance = expr.StrContains(arg1, arg2) - assert instance.name == "str_contains" + instance = Expr.string_contains(arg1, arg2) + assert instance.name == "string_contains" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.str_contains(Substring)" + assert repr(instance) == "Expr.string_contains(Substring)" + infix_instance = arg1.string_contains(arg2) + assert infix_instance == instance def test_xor(self): arg1 = self._make_arg("Condition1") @@ -899,333 +931,268 @@ def test_xor(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Xor(Condition1, Condition2)" - -class TestFunctionClasses: - """ - contains test methods for each Expr class that derives from Function - """ - - @pytest.mark.parametrize( - "method,args,result_cls", - [ - ("add", 
("field", 2), expr.Add), - ("subtract", ("field", 2), expr.Subtract), - ("multiply", ("field", 2), expr.Multiply), - ("divide", ("field", 2), expr.Divide), - ("mod", ("field", 2), expr.Mod), - ("logical_max", ("field", 2), expr.LogicalMax), - ("logical_min", ("field", 2), expr.LogicalMin), - ("eq", ("field", 2), expr.Eq), - ("neq", ("field", 2), expr.Neq), - ("lt", ("field", 2), expr.Lt), - ("lte", ("field", 2), expr.Lte), - ("gt", ("field", 2), expr.Gt), - ("gte", ("field", 2), expr.Gte), - ("in_any", ("field", [None]), expr.In), - ("not_in_any", ("field", [None]), expr.Not), - ("array_contains", ("field", None), expr.ArrayContains), - ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), - ("array_contains_any", ("field", [None]), expr.ArrayContainsAny), - ("array_length", ("field",), expr.ArrayLength), - ("array_reverse", ("field",), expr.ArrayReverse), - ("is_nan", ("field",), expr.IsNaN), - ("exists", ("field",), expr.Exists), - ("sum", ("field",), expr.Sum), - ("avg", ("field",), expr.Avg), - ("count", ("field",), expr.Count), - ("count", (), expr.Count), - ("min", ("field",), expr.Min), - ("max", ("field",), expr.Max), - ("char_length", ("field",), expr.CharLength), - ("byte_length", ("field",), expr.ByteLength), - ("like", ("field", "pattern"), expr.Like), - ("regex_contains", ("field", "regex"), expr.RegexContains), - ("regex_matches", ("field", "regex"), expr.RegexMatch), - ("str_contains", ("field", "substring"), expr.StrContains), - ("starts_with", ("field", "prefix"), expr.StartsWith), - ("ends_with", ("field", "postfix"), expr.EndsWith), - ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat), - ("map_get", ("field", "key"), expr.MapGet), - ("vector_length", ("field",), expr.VectorLength), - ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), - ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp), - ("timestamp_to_unix_millis", ("field",), expr.TimestampToUnixMillis), - ("unix_millis_to_timestamp", ("field",), expr.UnixMillisToTimestamp), - ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds), - ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp), - ("timestamp_add", ("field", "day", 1), expr.TimestampAdd), - ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub), - ], - ) - def test_function_builder(self, method, args, result_cls): - """ - Test building functions using methods exposed on base Function class. 
- """ - method_ptr = getattr(expr.Function, method) - - result = method_ptr(*args) - assert isinstance(result, result_cls) - - @pytest.mark.parametrize( - "first,second,expected", - [ - (expr.ArrayElement(), expr.ArrayElement(), True), - (expr.ArrayElement(), expr.CharLength(1), False), - (expr.ArrayElement(), object(), False), - (expr.ArrayElement(), None, False), - (expr.CharLength(1), expr.ArrayElement(), False), - (expr.CharLength(1), expr.CharLength(2), False), - (expr.CharLength(1), expr.CharLength(1), True), - (expr.CharLength(1), expr.ByteLength(1), False), - ], - ) - def test_equality(self, first, second, expected): - assert (first == second) is expected - - def _make_arg(self, name="Mock"): - arg = mock.Mock() - arg.__repr__ = lambda x: name - return arg - def test_divide(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Divide(arg1, arg2) + instance = Expr.divide(arg1, arg2) assert instance.name == "divide" assert instance.params == [arg1, arg2] - assert repr(instance) == "Divide(Left, Right)" + assert repr(instance) == "Left.divide(Right)" + infix_instance = arg1.divide(arg2) + assert infix_instance == instance - def test_logical_max(self): + def test_logical_maximum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMax(arg1, arg2) - assert instance.name == "logical_maximum" + instance = Expr.logical_maximum(arg1, arg2) + assert instance.name == "maximum" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMax(Left, Right)" + assert repr(instance) == "Left.logical_maximum(Right)" + infix_instance = arg1.logical_maximum(arg2) + assert infix_instance == instance - def test_logical_min(self): + def test_logical_minimum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMin(arg1, arg2) - assert instance.name == "logical_minimum" + instance = Expr.logical_minimum(arg1, arg2) + assert instance.name == "minimum" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMin(Left, Right)" + assert repr(instance) == "Left.logical_minimum(Right)" + infix_instance = arg1.logical_minimum(arg2) + assert infix_instance == instance def test_map_get(self): arg1 = self._make_arg("Map") - arg2 = expr.Constant("Key") - instance = expr.MapGet(arg1, arg2) + arg2 = "key" + instance = Expr.map_get(arg1, arg2) assert instance.name == "map_get" - assert instance.params == [arg1, arg2] - assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_get(Constant.of('key'))" + infix_instance = arg1.map_get(Constant.of(arg2)) + assert infix_instance == instance def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Mod(arg1, arg2) + instance = Expr.mod(arg1, arg2) assert instance.name == "mod" assert instance.params == [arg1, arg2] - assert repr(instance) == "Mod(Left, Right)" + assert repr(instance) == "Left.mod(Right)" + infix_instance = arg1.mod(arg2) + assert infix_instance == instance def test_multiply(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Multiply(arg1, arg2) + instance = Expr.multiply(arg1, arg2) assert instance.name == "multiply" assert instance.params == [arg1, arg2] - assert repr(instance) == "Multiply(Left, Right)" - - def test_parent(self): - arg1 = self._make_arg("Value") - instance = expr.Parent(arg1) - assert instance.name == "parent" - assert instance.params 
== [arg1] - assert repr(instance) == "Parent(Value)" + assert repr(instance) == "Left.multiply(Right)" + infix_instance = arg1.multiply(arg2) + assert infix_instance == instance - def test_str_concat(self): + def test_string_concat(self): arg1 = self._make_arg("Str1") arg2 = self._make_arg("Str2") - instance = expr.StrConcat(arg1, arg2) - assert instance.name == "str_concat" - assert instance.params == [arg1, arg2] - assert repr(instance) == "StrConcat(Str1, Str2)" + arg3 = self._make_arg("Str3") + instance = Expr.string_concat(arg1, arg2, arg3) + assert instance.name == "string_concat" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Str1.string_concat(Str2, Str3)" + infix_instance = arg1.string_concat(arg2, arg3) + assert infix_instance == instance def test_subtract(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Subtract(arg1, arg2) + instance = Expr.subtract(arg1, arg2) assert instance.name == "subtract" assert instance.params == [arg1, arg2] - assert repr(instance) == "Subtract(Left, Right)" + assert repr(instance) == "Left.subtract(Right)" + infix_instance = arg1.subtract(arg2) + assert infix_instance == instance def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampAdd(arg1, arg2, arg3) + instance = Expr.timestamp_add(arg1, arg2, arg3) assert instance.name == "timestamp_add" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" + assert repr(instance) == "Timestamp.timestamp_add(Unit, Amount)" + infix_instance = arg1.timestamp_add(arg2, arg3) + assert infix_instance == instance - def test_timestamp_sub(self): + def test_timestamp_subtract(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampSub(arg1, arg2, arg3) - assert instance.name == "timestamp_sub" + instance = Expr.timestamp_subtract(arg1, arg2, arg3) + assert instance.name == "timestamp_subtract" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + assert repr(instance) == "Timestamp.timestamp_subtract(Unit, Amount)" + infix_instance = arg1.timestamp_subtract(arg2, arg3) + assert infix_instance == instance def test_timestamp_to_unix_micros(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixMicros(arg1) + instance = Expr.timestamp_to_unix_micros(arg1) assert instance.name == "timestamp_to_unix_micros" assert instance.params == [arg1] - assert repr(instance) == "TimestampToUnixMicros(Input)" + assert repr(instance) == "Input.timestamp_to_unix_micros()" + infix_instance = arg1.timestamp_to_unix_micros() + assert infix_instance == instance def test_timestamp_to_unix_millis(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixMillis(arg1) + instance = Expr.timestamp_to_unix_millis(arg1) assert instance.name == "timestamp_to_unix_millis" assert instance.params == [arg1] - assert repr(instance) == "TimestampToUnixMillis(Input)" + assert repr(instance) == "Input.timestamp_to_unix_millis()" + infix_instance = arg1.timestamp_to_unix_millis() + assert infix_instance == instance def test_timestamp_to_unix_seconds(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixSeconds(arg1) + instance = Expr.timestamp_to_unix_seconds(arg1) assert instance.name == "timestamp_to_unix_seconds" assert instance.params == [arg1] - assert 
repr(instance) == "TimestampToUnixSeconds(Input)" + assert repr(instance) == "Input.timestamp_to_unix_seconds()" + infix_instance = arg1.timestamp_to_unix_seconds() + assert infix_instance == instance def test_unix_micros_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixMicrosToTimestamp(arg1) + instance = Expr.unix_micros_to_timestamp(arg1) assert instance.name == "unix_micros_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixMicrosToTimestamp(Input)" + assert repr(instance) == "Input.unix_micros_to_timestamp()" + infix_instance = arg1.unix_micros_to_timestamp() + assert infix_instance == instance def test_unix_millis_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixMillisToTimestamp(arg1) + instance = Expr.unix_millis_to_timestamp(arg1) assert instance.name == "unix_millis_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixMillisToTimestamp(Input)" + assert repr(instance) == "Input.unix_millis_to_timestamp()" + infix_instance = arg1.unix_millis_to_timestamp() + assert infix_instance == instance def test_unix_seconds_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixSecondsToTimestamp(arg1) + instance = Expr.unix_seconds_to_timestamp(arg1) assert instance.name == "unix_seconds_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixSecondsToTimestamp(Input)" + assert repr(instance) == "Input.unix_seconds_to_timestamp()" + infix_instance = arg1.unix_seconds_to_timestamp() + assert infix_instance == instance def test_vector_length(self): arg1 = self._make_arg("Array") - instance = expr.VectorLength(arg1) + instance = Expr.vector_length(arg1) assert instance.name == "vector_length" assert instance.params == [arg1] - assert repr(instance) == "VectorLength(Array)" + assert repr(instance) == "Array.vector_length()" + infix_instance = arg1.vector_length() + assert infix_instance == instance def test_add(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Add(arg1, arg2) + instance = Expr.add(arg1, arg2) assert instance.name == "add" assert instance.params == [arg1, arg2] - assert repr(instance) == "Add(Left, Right)" - - def test_array_element(self): - instance = expr.ArrayElement() - assert instance.name == "array_element" - assert instance.params == [] - assert repr(instance) == "ArrayElement()" - - def test_array_filter(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("FilterCond") - instance = expr.ArrayFilter(arg1, arg2) - assert instance.name == "array_filter" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayFilter(Array, FilterCond)" + assert repr(instance) == "Left.add(Right)" + infix_instance = arg1.add(arg2) + assert infix_instance == instance def test_array_length(self): arg1 = self._make_arg("Array") - instance = expr.ArrayLength(arg1) + instance = Expr.array_length(arg1) assert instance.name == "array_length" assert instance.params == [arg1] - assert repr(instance) == "ArrayLength(Array)" + assert repr(instance) == "Array.array_length()" + infix_instance = arg1.array_length() + assert infix_instance == instance def test_array_reverse(self): arg1 = self._make_arg("Array") - instance = expr.ArrayReverse(arg1) + instance = Expr.array_reverse(arg1) assert instance.name == "array_reverse" assert instance.params == [arg1] - assert repr(instance) == "ArrayReverse(Array)" - - def test_array_transform(self): - arg1 = self._make_arg("Array") - arg2 = 
self._make_arg("TransformFunc") - instance = expr.ArrayTransform(arg1, arg2) - assert instance.name == "array_transform" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayTransform(Array, TransformFunc)" + assert repr(instance) == "Array.array_reverse()" + infix_instance = arg1.array_reverse() + assert infix_instance == instance def test_byte_length(self): arg1 = self._make_arg("Expr") - instance = expr.ByteLength(arg1) + instance = Expr.byte_length(arg1) assert instance.name == "byte_length" assert instance.params == [arg1] - assert repr(instance) == "ByteLength(Expr)" + assert repr(instance) == "Expr.byte_length()" + infix_instance = arg1.byte_length() + assert infix_instance == instance def test_char_length(self): arg1 = self._make_arg("Expr") - instance = expr.CharLength(arg1) + instance = Expr.char_length(arg1) assert instance.name == "char_length" assert instance.params == [arg1] - assert repr(instance) == "CharLength(Expr)" + assert repr(instance) == "Expr.char_length()" + infix_instance = arg1.char_length() + assert infix_instance == instance def test_collection_id(self): arg1 = self._make_arg("Value") - instance = expr.CollectionId(arg1) + instance = Expr.collection_id(arg1) assert instance.name == "collection_id" assert instance.params == [arg1] - assert repr(instance) == "CollectionId(Value)" + assert repr(instance) == "Value.collection_id()" + infix_instance = arg1.collection_id() + assert infix_instance == instance def test_sum(self): arg1 = self._make_arg("Value") - instance = expr.Sum(arg1) + instance = Expr.sum(arg1) assert instance.name == "sum" assert instance.params == [arg1] - assert repr(instance) == "Sum(Value)" + assert repr(instance) == "Value.sum()" + infix_instance = arg1.sum() + assert infix_instance == instance - def test_avg(self): + def test_average(self): arg1 = self._make_arg("Value") - instance = expr.Avg(arg1) - assert instance.name == "avg" + instance = Expr.average(arg1) + assert instance.name == "average" assert instance.params == [arg1] - assert repr(instance) == "Avg(Value)" + assert repr(instance) == "Value.average()" + infix_instance = arg1.average() + assert infix_instance == instance def test_count(self): arg1 = self._make_arg("Value") - instance = expr.Count(arg1) + instance = Expr.count(arg1) assert instance.name == "count" assert instance.params == [arg1] - assert repr(instance) == "Count(Value)" - - def test_count_empty(self): - instance = expr.Count() - assert instance.params == [] - assert repr(instance) == "Count()" + assert repr(instance) == "Value.count()" + infix_instance = arg1.count() + assert infix_instance == instance - def test_min(self): + def test_minimum(self): arg1 = self._make_arg("Value") - instance = expr.Min(arg1) + instance = Expr.minimum(arg1) assert instance.name == "minimum" assert instance.params == [arg1] - assert repr(instance) == "Min(Value)" + assert repr(instance) == "Value.minimum()" + infix_instance = arg1.minimum() + assert infix_instance == instance - def test_max(self): + def test_maximum(self): arg1 = self._make_arg("Value") - instance = expr.Max(arg1) + instance = Expr.maximum(arg1) assert instance.name == "maximum" assert instance.params == [arg1] - assert repr(instance) == "Max(Value)" + assert repr(instance) == "Value.maximum()" + infix_instance = arg1.maximum() + assert infix_instance == instance diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index e67a4ca3a..d5b36e56c 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ 
b/tests/unit/v1/test_pipeline_stages.py @@ -21,8 +21,6 @@ Constant, Field, Ordering, - Sum, - Count, ) from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1._helpers import GeoPoint @@ -79,8 +77,8 @@ def _make_one(self, *args, **kwargs): def test_ctor_positional(self): """test with only positional arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") instance = self._make_one(sum_total, avg_price) assert list(instance.accumulators) == [sum_total, avg_price] assert len(instance.groups) == 0 @@ -88,8 +86,8 @@ def test_ctor_positional(self): def test_ctor_keyword(self): """test with only keyword arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") group_category = Field.of("category") instance = self._make_one( accumulators=[avg_price, sum_total], groups=[group_category, "city"] @@ -103,24 +101,24 @@ def test_ctor_keyword(self): def test_ctor_combined(self): """test with a mix of arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") - count = Count(Field.of("total")).as_("count") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") + count = Field.of("total").count().as_("count") with pytest.raises(ValueError): self._make_one(sum_total, accumulators=[avg_price, count]) def test_repr(self): - sum_total = Sum(Field.of("total")).as_("sum_total") + sum_total = Field.of("total").sum().as_("sum_total") group_category = Field.of("category") instance = self._make_one(sum_total, groups=[group_category]) repr_str = repr(instance) assert ( repr_str - == "Aggregate(Sum(Field.of('total')).as_('sum_total'), groups=[Field.of('category')])" + == "Aggregate(Field.of('total').sum().as_('sum_total'), groups=[Field.of('category')])" ) def test_to_pb(self): - sum_total = Sum(Field.of("total")).as_("sum_total") + sum_total = Field.of("total").sum().as_("sum_total") group_category = Field.of("category") instance = self._make_one(sum_total, groups=[group_category]) result = instance._to_pb() @@ -790,19 +788,21 @@ def _make_one(self, *args, **kwargs): return stages.Where(*args, **kwargs) def test_repr(self): - condition = Field.of("age").gt(30) + condition = Field.of("age").greater_than(30) instance = self._make_one(condition) repr_str = repr(instance) - assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))" + assert ( + repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + ) def test_to_pb(self): - condition = Field.of("city").eq("SF") + condition = Field.of("city").equal("SF") instance = self._make_one(condition) result = instance._to_pb() assert result.name == "where" assert len(result.args) == 1 got_fn = result.args[0].function_value - assert got_fn.name == "eq" + assert got_fn.name == "equal" assert len(got_fn.args) == 2 assert got_fn.args[0].field_reference_value == "city" assert got_fn.args[1].string_value == "SF" From fed7af2aee6c8e59ecf89c5a684dd61dc44f0d82 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 23 Oct 2025 13:54:00 -0700 Subject: [PATCH 05/27] feat: query to pipeline conversion (#1071) --- 
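A sketch of the round trip this patch enables (collection and field names below are placeholders, and running a pipeline assumes an enterprise-mode database; other databases reject pipeline execution with FailedPrecondition, as the system tests below account for):

.. code-block:: python

    from google.cloud import firestore
    from google.cloud.firestore_v1.base_query import FieldFilter

    client = firestore.Client()
    query = (
        client.collection("users")
        .where(filter=FieldFilter("age", ">", 18))
        .order_by("age")
        .limit(10)
    )

    # New in this patch: derive an equivalent pipeline from the classic query.
    pipeline = query.pipeline()
    for result in pipeline.execute():
        print(result.data())
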
google/cloud/firestore_v1/_pipeline_stages.py | 2 +- google/cloud/firestore_v1/base_aggregation.py | 57 ++++- google/cloud/firestore_v1/base_collection.py | 13 ++ google/cloud/firestore_v1/base_query.py | 69 ++++++ .../firestore_v1/pipeline_expressions.py | 48 +++- tests/system/test__helpers.py | 1 + tests/system/test_system.py | 135 ++++++++--- tests/system/test_system_async.py | 103 +++++++-- tests/unit/v1/test_aggregation.py | 215 ++++++++++++++++++ tests/unit/v1/test_async_aggregation.py | 145 ++++++++++++ tests/unit/v1/test_async_collection.py | 20 ++ tests/unit/v1/test_async_query.py | 19 ++ tests/unit/v1/test_base_collection.py | 14 ++ tests/unit/v1/test_base_query.py | 170 ++++++++++++++ tests/unit/v1/test_collection.py | 22 ++ tests/unit/v1/test_pipeline_expressions.py | 15 ++ tests/unit/v1/test_pipeline_stages.py | 5 +- tests/unit/v1/test_query.py | 19 ++ 18 files changed, 1019 insertions(+), 53 deletions(-) diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index aefddbcf8..7233a8eec 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -216,7 +216,7 @@ def __init__(self, collection_id: str): self.collection_id = collection_id def _pb_args(self): - return [Value(string_value=self.collection_id)] + return [Value(reference_value=""), Value(string_value=self.collection_id)] class Database(Stage): diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index c5e6a7b7f..3dd7a453e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -21,9 +21,10 @@ from __future__ import annotations import abc +import itertools from abc import ABC -from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union, Iterable from google.api_core import gapic_v1 from google.api_core import retry as retries @@ -33,6 +34,10 @@ from google.cloud.firestore_v1.types import ( StructuredAggregationQuery, ) +from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction +from google.cloud.firestore_v1.pipeline_expressions import Count +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr +from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints if TYPE_CHECKING: # pragma: NO COVER @@ -66,6 +71,9 @@ def __init__(self, alias: str, value: float, read_time=None): def __repr__(self): return f"" + def _to_dict(self): + return {self.alias: self.value} + class BaseAggregation(ABC): def __init__(self, alias: str | None = None): @@ -75,6 +83,27 @@ def __init__(self, alias: str | None = None): def _to_protobuf(self): """Convert this instance to the protobuf representation""" + @abc.abstractmethod + def _to_pipeline_expr( + self, autoindexer: Iterable[int] + ) -> AliasedExpr[AggregateFunction]: + """ + Convert this instance to a pipeline expression for use with pipeline.aggregate() + + Args: + autoindexer: If an alias isn't supplied, one should be created with the format "field_n" + The autoindexer is an iterable that provides the `n` value to use for each expression + """ + + def _pipeline_alias(self, autoindexer): + """ + Helper to build the alias for the pipeline expression + """ + if self.alias is not None: + return self.alias + else: + return f"field_{next(autoindexer)}" + class CountAggregation(BaseAggregation): def 
__init__(self, alias: str | None = None): @@ -88,6 +117,9 @@ def _to_protobuf(self): aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Count().as_(self._pipeline_alias(autoindexer)) + class SumAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -107,6 +139,9 @@ def _to_protobuf(self): aggregation_pb.sum.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer)) + class AvgAggregation(BaseAggregation): def __init__(self, field_ref: str | FieldPath, alias: str | None = None): @@ -126,6 +161,9 @@ def _to_protobuf(self): aggregation_pb.avg.field.field_path = self.field_ref return aggregation_pb + def _to_pipeline_expr(self, autoindexer: Iterable[int]): + return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer)) + def _query_response_to_result( response_pb, @@ -317,3 +355,20 @@ def stream( StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]: A generator of the query results. """ + + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - ValueError: raised if Query wasn't created with an associated client + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + # use autoindexer to keep track of which field number to use for un-aliased fields + autoindexer = itertools.count(start=1) + exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] + return self._nested_query.pipeline().aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 1b1ef0411..a4cc2b1b7 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -602,6 +602,19 @@ def find_nearest( distance_threshold=distance_threshold, ) + def pipeline(self): + """ + Convert this query into a Pipeline + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a Pipeline representing the query + """ + return self._query().pipeline() + def _auto_id() -> str: """Generate a "random" automatically generated ID. 
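The `field_n` aliasing contract above can be exercised directly; a minimal sketch (aggregation field names are placeholders): aliased aggregations keep their alias and consume no index, while un-aliased ones draw successive values from the shared counter.

.. code-block:: python

    import itertools

    from google.cloud.firestore_v1.base_aggregation import (
        AvgAggregation,
        CountAggregation,
        SumAggregation,
    )

    autoindexer = itertools.count(start=1)
    aggregations = [
        CountAggregation(),                  # un-aliased -> "field_1"
        SumAggregation("total", alias="t"),  # keeps "t", consumes no index
        AvgAggregation("price"),             # un-aliased -> "field_2"
    ]
    exprs = [a._to_pipeline_expr(autoindexer) for a in aggregations]
    assert [e.alias for e in exprs] == ["field_1", "t", "field_2"]
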
diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py
index 2de95b79a..797572b1b 100644
--- a/google/cloud/firestore_v1/base_query.py
+++ b/google/cloud/firestore_v1/base_query.py
@@ -59,6 +59,7 @@
     query,
 )
 from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1 import pipeline_expressions
 
 if TYPE_CHECKING:  # pragma: NO COVER
     from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
@@ -1128,6 +1129,74 @@ def recursive(self: QueryType) -> QueryType:
 
         return copied
 
+    def pipeline(self):
+        """
+        Convert this query into a Pipeline
+
+        Queries containing a `cursor` or `limit_to_last` are not currently supported
+
+        Raises:
+            - ValueError: raised if Query wasn't created with an associated client
+            - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
+        Returns:
+            a Pipeline representing the query
+        """
+        if not self._client:
+            raise ValueError("Query does not have an associated client")
+        if self._all_descendants:
+            ppl = self._client.pipeline().collection_group(self._parent.id)
+        else:
+            ppl = self._client.pipeline().collection(self._parent._path)
+
+        # Filters
+        for filter_ in self._field_filters:
+            ppl = ppl.where(
+                pipeline_expressions.BooleanExpr._from_query_filter_pb(
+                    filter_, self._client
+                )
+            )
+
+        # Projections
+        if self._projection and self._projection.fields:
+            ppl = ppl.select(*[field.field_path for field in self._projection.fields])
+
+        # Orders
+        orders = self._normalize_orders()
+        if orders:
+            exists = []
+            orderings = []
+            for order in orders:
+                field = pipeline_expressions.Field.of(order.field.field_path)
+                exists.append(field.exists())
+                direction = (
+                    "ascending"
+                    if order.direction == StructuredQuery.Direction.ASCENDING
+                    else "descending"
+                )
+                orderings.append(pipeline_expressions.Ordering(field, direction))
+
+            # Add exists filters to match Query's implicit orderby semantics.
+            if len(exists) == 1:
+                ppl = ppl.where(exists[0])
+            else:
+                ppl = ppl.where(pipeline_expressions.And(*exists))
+
+            # Add sort orderings
+            ppl = ppl.sort(*orderings)
+
+        # Cursors, Limit and Offset
+        if self._start_at or self._end_at or self._limit_to_last:
+            raise NotImplementedError(
+                "Query to Pipeline conversion: cursors and limit_to_last are not supported yet."
+            )
+        else:  # Limit & Offset without cursors
+            if self._offset:
+                ppl = ppl.offset(self._offset)
+            if self._limit:
+                ppl = ppl.limit(self._limit)
+
+        return ppl
+
     def _comparator(self, doc1, doc2) -> int:
         _orders = self._orders
 
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index ef57f5b72..4639e0f7d 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -587,6 +587,18 @@ def is_nan(self) -> "BooleanExpr":
         """
         return BooleanExpr("is_nan", [self])
 
+    @expose_as_static
+    def is_null(self) -> "BooleanExpr":
+        """Creates an expression that checks if this expression evaluates to 'Null'.
+
+        Example:
+            >>> Field.of("value").is_null()
+
+        Returns:
+            A new `Expr` representing the 'isNull' check.
+        """
+        return BooleanExpr("is_null", [self])
+
     @expose_as_static
     def exists(self) -> "BooleanExpr":
         """Creates an expression that checks if a field exists in the document.
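The new `is_null` expression above is what the unary-filter conversion (in `_from_query_filter_pb`, below) builds on; a sketch of the resulting shapes, with a placeholder field name:

.. code-block:: python

    from google.cloud.firestore_v1.pipeline_expressions import And, Field, Not

    field = Field.of("value")

    # FieldFilter("value", "==", None) maps to:
    is_null = And(field.exists(), field.is_null())

    # FieldFilter("value", "!=", None) maps to:
    is_not_null = And(field.exists(), Not(field.is_null()))
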
@@ -627,6 +639,7 @@ def average(self) -> "Expr":
         """
         return AggregateFunction("average", [self])
 
+    @expose_as_static
     def count(self) -> "Expr":
         """Creates an aggregation that counts the number of stage inputs with valid evaluations of
         the expression or field.
@@ -1312,9 +1325,9 @@ def _from_query_filter_pb(filter_pb, client):
             elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN:
                 return And(field.exists(), Not(field.is_nan()))
             elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL:
-                return And(field.exists(), field.equal(None))
+                return And(field.exists(), field.is_null())
             elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL:
-                return And(field.exists(), Not(field.equal(None)))
+                return And(field.exists(), Not(field.is_null()))
             else:
                 raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}")
         elif isinstance(filter_pb, Query_pb.FieldFilter):
@@ -1361,7 +1374,7 @@ class And(BooleanExpr):
     Example:
         >>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND
        >>> # the 'status' field is "active"
-        >>> Expr.And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
+        >>> And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
 
     Args:
         *conditions: The filter conditions to 'AND' together.
@@ -1377,7 +1390,7 @@ class Not(BooleanExpr):
     Example:
         >>> # Find documents where the 'completed' field is NOT true
-        >>> Expr.Not(Field.of("completed").equal(True))
+        >>> Not(Field.of("completed").equal(True))
 
     Args:
         condition: The filter condition to negate.
@@ -1394,7 +1407,7 @@ class Or(BooleanExpr):
     Example:
        >>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR
        >>> # the 'status' field is "active"
-        >>> Expr.Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
+        >>> Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
 
     Args:
         *conditions: The filter conditions to 'OR' together.
@@ -1411,7 +1424,7 @@ class Xor(BooleanExpr):
     Example:
        >>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London",
        >>> # or 'status' is "active".
-        >>> Expr.Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
+        >>> Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active"))
 
     Args:
         *conditions: The filter conditions to 'XOR' together.
@@ -1428,7 +1441,7 @@ class Conditional(BooleanExpr):
     Example:
        >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor".
-        >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor"));
+        >>> Conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor"))
 
     Args:
         condition: The condition to evaluate.
@@ -1440,3 +1453,24 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr):
         super().__init__(
             "conditional", [condition, then_expr, else_expr], use_infix_repr=False
         )
+
+class Count(AggregateFunction):
+    """
+    Represents an aggregation that counts the number of stage inputs with valid evaluations of the
+    expression or field.
+
+    Example:
+        >>> # Count the total number of products
+        >>> Field.of("productId").count().as_("totalProducts")
+        >>> Count(Field.of("productId"))
+        >>> Count().as_("count")
+
+    Args:
+        expression: The expression or field to count. If None, counts all stage inputs.
+    """
+
+    def __init__(self, expression: Expr | None = None):
+        expression_list = [expression] if expression else []
+        super().__init__(
+            "count", expression_list, use_infix_repr=bool(expression_list)
+        )
diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py
index c146a5763..5a93a869e 100644
--- a/tests/system/test__helpers.py
+++ b/tests/system/test__helpers.py
@@ -20,3 +20,4 @@
 # run all tests against default database, and a named database
 # TODO: add enterprise mode when GA (RunQuery not currently supported)
 TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
+TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB]
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 9909fb05e..a8f94e2ba 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -42,7 +42,9 @@
     MISSING_DOCUMENT,
     RANDOM_ID_REGEX,
     UNIQUE_RESOURCE_ID,
+    ENTERPRISE_MODE_ERROR,
     TEST_DATABASES,
+    TEST_DATABASES_W_ENTERPRISE,
 )
 
@@ -80,6 +82,58 @@ def cleanup():
         operation()
 
 
+def verify_pipeline(query):
+    """
+    This function ensures a pipeline produces the same
+    results as the query it is derived from
+
+    It can be attached to existing query tests to check both
+    modalities at the same time
+    """
+    from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
+
+    def _clean_results(results):
+        if isinstance(results, dict):
+            return {k: _clean_results(v) for k, v in results.items()}
+        elif isinstance(results, list):
+            return [_clean_results(r) for r in results]
+        elif isinstance(results, float) and math.isnan(results):
+            return "__NAN_VALUE__"
+        else:
+            return results
+
+    query_exception = None
+    query_results = None
+    try:
+        try:
+            if isinstance(query, BaseAggregationQuery):
+                # aggregation queries return a list of lists of aggregation results
+                query_results = _clean_results(
+                    list(itertools.chain.from_iterable(
+                        [[a._to_dict() for a in s] for s in query.get()]
+                    ))
+                )
+            else:
+                # other queries return a simple list of results
+                query_results = _clean_results([s.to_dict() for s in query.get()])
+        except Exception as e:
+            # if we expect the query to fail, capture the exception
+            query_exception = e
+        pipeline = query.pipeline()
+        if query_exception:
+            # ensure that the pipeline uses same error as query
+            with pytest.raises(query_exception.__class__):
+                pipeline.execute()
+        else:
+            # ensure results match query
+            pipeline_results = _clean_results([s.data() for s in pipeline.execute()])
+            assert query_results == pipeline_results
+    except FailedPrecondition as e:
+        # if testing against a non-enterprise db, skip this check
+        if ENTERPRISE_MODE_ERROR not in e.message:
+            raise e
+
+
 @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collections(client, database):
     collections = list(client.collections())
@@ -1231,7 +1285,7 @@ def query(collection):
     return collection.where(filter=FieldFilter("a", "==", 1))
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_query_stream_legacy_where(query_docs, database):
     """Assert the legacy code still works and returns value"""
     collection, stored, allowed_vals = query_docs
@@ -1245,9 +1299,10 @@ def test_query_stream_legacy_where(query_docs, database):
     for key, value in values.items():
         assert stored[key] == value
         assert value["a"] == 1
+    verify_pipeline(query)
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1256,9 +1311,10 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1267,9 +1323,10 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1279,9 +1336,10 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) @@ -1301,9 +1359,10 @@ def test_query_stream_w_not_eq_op(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1313,9 +1372,10 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()} assert len(values) == 22 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1327,9 +1387,10 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1341,9 +1402,10 @@ def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1363,6 +1425,7 @@ def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1381,13 +1444,14 @@ def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = list(query.stream()) assert len(values) == 0 + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1407,7 +1471,7 @@ def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1425,9 +1489,10 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1441,13 +1506,14 @@ def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1463,7 +1529,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): results.get_explain_metrics() - + verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
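Why the `_clean_results` helper above substitutes a sentinel for NaN before comparing: NaN never compares equal to itself, so identical query and pipeline outputs containing NaN would otherwise look different. A minimal illustration:

.. code-block:: python

    import math

    assert float("nan") != float("nan")  # NaN is never equal to itself

    def clean(value):
        # mirrors the sentinel substitution in _clean_results above
        if isinstance(value, float) and math.isnan(value):
            return "__NAN_VALUE__"
        return value

    assert clean(float("nan")) == clean(float("nan"))
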
@@ -1646,7 +1712,7 @@ def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1672,6 +1738,7 @@ def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + verify_pipeline(query0) # 1. Query for a NAN. query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1682,6 +1749,7 @@ def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + verify_pipeline(query1) # 2. Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1701,7 +1769,7 @@ def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1731,7 +1799,8 @@ def test_collection_group_queries(client, cleanup, database): snapshots = list(query.stream()) found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1777,7 +1846,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database): assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1820,6 +1889,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1841,6 +1911,7 @@ def test_collection_group_queries_filters(client, cleanup, database): snapshots = list(query.stream()) found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + verify_pipeline(query) @pytest.mark.skipif( @@ -2129,7 +2200,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) @@ -2150,6 +2221,7 @@ def on_snapshot(docs, changes, read_time): query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) + verify_pipeline(query_ran_query) on_snapshot.called_count = 0 @@ -2490,7 +2562,7 @@ def test_chunked_and_recursive(client, cleanup, database): assert [doc.id for doc in next(iter)] == page_3_ids 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") @@ -2527,6 +2599,7 @@ def on_snapshot(docs, changes, read_time): ), "expect the sort order to match, born" on_snapshot.called_count += 1 on_snapshot.last_doc_count = len(docs) + verify_pipeline(query_ref) except Exception as e: on_snapshot.failed = e @@ -2566,7 +2639,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_repro_429(client, cleanup, database): # See: https://github.com/googleapis/python-firestore/issues/429 now = datetime.datetime.now(tz=datetime.timezone.utc) @@ -2592,6 +2665,7 @@ def test_repro_429(client, cleanup, database): for snapshot in query2.stream(): print(f"id: {snapshot.id}") + verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -3160,7 +3234,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_and_composite_filter(collection, database): and_filter = And( filters=[ @@ -3173,9 +3247,10 @@ def test_query_with_and_composite_filter(collection, database): for result in query.stream(): assert result.get("stats.product") > 5 assert result.get("stats.product") < 10 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_or_composite_filter(collection, database): or_filter = Or( filters=[ @@ -3196,9 +3271,10 @@ def test_query_with_or_composite_filter(collection, database): assert gt_5 > 0 assert lt_10 > 0 + verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) @@ -3243,9 +3319,10 @@ def test_aggregation_queries_with_read_time( assert len(old_result) == 1 for r in old_result[0]: assert r.value == expected_value + verify_pipeline(aggregation_query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_query_with_complex_composite_filter(collection, database): field_filter = FieldFilter("b", "==", 0) or_filter = Or( @@ -3266,6 +3343,7 @@ def test_query_with_complex_composite_filter(collection, database): assert sum_0 > 0 assert sum_4 > 0 + verify_pipeline(query) # b == 3 || (stats.sum == 4 && a == 4) comp_filter = Or( @@ -3288,13 +3366,14 @@ def test_query_with_complex_composite_filter(collection, database): assert b_3 is True assert b_not_3 is True + verify_pipeline(query) @pytest.mark.parametrize( "aggregation_type,aggregation_args,expected", [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)], ) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) def test_aggregation_query_in_transaction( client, cleanup, @@ -3335,13 +3414,14 @@ def 
in_transaction(transaction):
         assert len(result[0]) == 1
         assert result[0][0].value == expected
         inner_fn_ran = True
+        verify_pipeline(aggregation_query)
 
     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
     assert inner_fn_ran is True
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 def test_or_query_in_transaction(client, cleanup, database):
     """
     Test running or query inside a transaction. Should pass transaction id along with request
     """
@@ -3380,6 +3460,7 @@ def in_transaction(transaction):
             result[0].get("b") == 2 and result[1].get("b") == 1
         )
         inner_fn_ran = True
+        verify_pipeline(query)
 
     in_transaction(transaction)
     # make sure we didn't skip assertions in inner function
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
index bc79ee2df..b78a77786 100644
--- a/tests/system/test_system_async.py
+++ b/tests/system/test_system_async.py
@@ -53,7 +53,9 @@
     MISSING_DOCUMENT,
     RANDOM_ID_REGEX,
     UNIQUE_RESOURCE_ID,
+    ENTERPRISE_MODE_ERROR,
     TEST_DATABASES,
+    TEST_DATABASES_W_ENTERPRISE,
 )
 
 RETRIES = retries.AsyncRetry(
@@ -160,6 +162,61 @@ async def cleanup():
     await operation()
 
 
+async def verify_pipeline(query):
+    """
+    This function ensures a pipeline produces the same
+    results as the query it is derived from
+
+    It can be attached to existing query tests to check both
+    modalities at the same time
+    """
+    from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
+
+    def _clean_results(results):
+        if isinstance(results, dict):
+            return {k: _clean_results(v) for k, v in results.items()}
+        elif isinstance(results, list):
+            return [_clean_results(r) for r in results]
+        elif isinstance(results, float) and math.isnan(results):
+            return "__NAN_VALUE__"
+        else:
+            return results
+
+    query_exception = None
+    query_results = None
+    try:
+        try:
+            if isinstance(query, BaseAggregationQuery):
+                # aggregation queries return a list of lists of aggregation results
+                query_results = _clean_results(
+                    list(
+                        itertools.chain.from_iterable(
+                            [[a._to_dict() for a in s] for s in await query.get()]
+                        )
+                    )
+                )
+            else:
+                # other queries return a simple list of results
+                query_results = _clean_results([s.to_dict() for s in await query.get()])
+        except Exception as e:
+            # if we expect the query to fail, capture the exception
+            query_exception = e
+        pipeline = query.pipeline()
+        if query_exception:
+            # ensure that the pipeline uses same error as query
+            with pytest.raises(query_exception.__class__):
+                await pipeline.execute()
+        else:
+            # ensure results match query
+            pipeline_results = _clean_results([s.data() async for s in pipeline.stream()])
+            assert query_results == pipeline_results
+    except FailedPrecondition as e:
+        # if testing against a non-enterprise db, skip this check
+        if ENTERPRISE_MODE_ERROR not in e.message:
+            raise e
+
+
+
 @pytest.fixture(scope="module")
 def event_loop():
     """Change event_loop fixture to module level."""
@@ -1203,7 +1260,7 @@ async def async_query(collection):
     return collection.where(filter=FieldFilter("a", "==", 1))
 
 
-@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
 async def test_query_stream_legacy_where(query_docs, database):
     """Assert the legacy code still works and returns value, and shows UserWarning"""
     collection, stored, allowed_vals = query_docs
@@ -1217,9 +1274,10 @@ async def
test_query_stream_legacy_where(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1228,9 +1286,10 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1239,9 +1298,10 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1251,9 +1311,10 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1265,9 +1326,10 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat for key, value in values.items(): assert stored[key] == value assert value["a"] == 1 + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1279,9 +1341,10 @@ async def test_query_stream_w_order_by(query_docs, database): b_vals.append(value["b"]) # Make sure the ``b``-values are in DESCENDING order. 
assert sorted(b_vals, reverse=True) == b_vals + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1301,6 +1364,7 @@ async def test_query_stream_w_field_path(query_docs, database): ] ) assert expected_ab_pairs == ab_pairs2 + await verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1319,13 +1383,14 @@ async def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) query = collection.where(filter=FieldFilter("b", "==", num_vals + 100)) values = [i async for i in query.stream()] assert len(values) == 0 + await verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1345,7 +1410,7 @@ async def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1363,9 +1428,10 @@ async def test_query_stream_w_multiple_filters(query_docs, database): assert stored[key] == value pair = (value["a"], value["b"]) assert pair in matching_pairs + await verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1379,13 +1445,14 @@ async def test_query_stream_w_offset(query_docs, database): for key, value in values.items(): assert stored[key] == value assert value["b"] == 2 + await verify_pipeline(query) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1404,6 +1471,7 @@ async def test_query_stream_or_get_w_no_explain_options(query_docs, database, me # is called with pytest.raises(QueryExplainError, match="explain_options not set on query"): await results.get_explain_metrics() + await verify_pipeline(query) @pytest.mark.skipif( @@ -1570,7 +1638,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1596,6 +1664,7 @@ async def test_query_unary(client, cleanup, database): snapshot0 = values0[0] assert snapshot0.reference._path == document0._path assert snapshot0.to_dict() == {field_name: None} + await verify_pipeline(query0) # 1. Query for a NAN. query1 = collection.where(filter=FieldFilter(field_name, "==", nan_val)) @@ -1606,6 +1675,7 @@ async def test_query_unary(client, cleanup, database): data1 = snapshot1.to_dict() assert len(data1) == 1 assert math.isnan(data1[field_name]) + await verify_pipeline(query1) # 2. Query for not null query2 = collection.where(filter=FieldFilter(field_name, "!=", None)) @@ -1625,7 +1695,7 @@ async def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1655,7 +1725,8 @@ async def test_collection_group_queries(client, cleanup, database): snapshots = [i async for i in query.stream()] found = [snapshot.id for snapshot in snapshots] expected = ["cg-doc1", "cg-doc2", "cg-doc3", "cg-doc4", "cg-doc5"] - assert found == expected + assert set(found) == set(expected) + await verify_pipeline(query) @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @@ -1701,7 +1772,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database) assert found == set(["cg-doc2"]) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) async def test_collection_group_queries_filters(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1743,6 +1814,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2", "cg-doc3", "cg-doc4"]) + await verify_pipeline(query) query = ( client.collection_group(collection_group) @@ -1764,6 +1836,7 @@ async def test_collection_group_queries_filters(client, cleanup, database): snapshots = [i async for i in query.stream()] found = set(snapshot.id for snapshot in snapshots) assert found == set(["cg-doc2"]) + await verify_pipeline(query) @pytest.mark.skipif( 
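The unit tests that follow pin down the aggregation conversion; in application code the flow looks roughly like this (a hedged sketch with placeholder names, again assuming an enterprise-mode database):

.. code-block:: python

    from google.cloud import firestore

    client = firestore.Client()
    agg_query = client.collection("orders").count(alias="n")
    agg_query.sum("total")  # un-aliased: surfaces as "field_1"

    pipeline = agg_query.pipeline()  # a Collection stage plus an Aggregate stage
    for row in pipeline.execute():
        print(row.data())  # e.g. {"n": ..., "field_1": ...}
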
diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py
index 69ca69ec7..5064e87ae 100644
--- a/tests/unit/v1/test_aggregation.py
+++ b/tests/unit/v1/test_aggregation.py
@@ -20,6 +20,7 @@
 from google.cloud.firestore_v1.base_aggregation import (
     AggregationResult,
     AvgAggregation,
+    BaseAggregation,
     CountAggregation,
     SumAggregation,
 )
@@ -27,6 +28,7 @@
 from google.cloud.firestore_v1.query_results import QueryResultsList
 from google.cloud.firestore_v1.stream_generator import StreamGenerator
 from google.cloud.firestore_v1.types import RunAggregationQueryResponse
+from google.cloud.firestore_v1.field_path import FieldPath
 from google.protobuf.timestamp_pb2 import Timestamp
 from tests.unit.v1._test_helpers import (
     make_aggregation_query,
@@ -121,6 +123,65 @@ def test_avg_aggregation_no_alias_to_pb():
     assert got_pb.alias == ""
 
 
+@pytest.mark.parametrize(
+    "in_alias,expected_alias", [("total", "total"), (None, "field_1")]
+)
+def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Count
+
+    count_aggregation = CountAggregation(alias=in_alias)
+    got = count_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Count)
+    assert len(got.expr.params) == 0
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Sum
+
+    sum_aggregation = SumAggregation(expected_path, alias=in_alias)
+    got = sum_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Sum)
+    assert got.expr.params[0].path == expected_path
+
+
+@pytest.mark.parametrize(
+    "in_alias,expected_path,expected_alias",
+    [("total", "path", "total"), (None, "some_ref", "field_1")],
+)
+def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias):
+    from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias
+    from google.cloud.firestore_v1.pipeline_expressions import Avg
+
+    avg_aggregation = AvgAggregation(expected_path, alias=in_alias)
+    got = avg_aggregation._to_pipeline_expr(iter([1]))
+    assert isinstance(got, ExprWithAlias)
+    assert got.alias == expected_alias
+    assert isinstance(got.expr, Avg)
+    assert got.expr.params[0].path == expected_path
+
+
+def test_aggregation__pipeline_alias_increment():
+    """
+    BaseAggregation._pipeline_alias should pull from an autoindexer to populate field numbers
+    """
+    autoindex = iter(range(10))
+    mock_instance = mock.Mock()
+    mock_instance.alias = None
+    for i in range(10):
+        got_name = BaseAggregation._pipeline_alias(mock_instance, autoindex)
+        assert got_name == f"field_{i}"
+
+
 def test_aggregation_query_constructor():
     client = make_client()
     parent = client.collection("dee")
@@ -894,6 +955,16 @@ def test_aggregation_query_stream_w_explain_options_analyze_false():
     _aggregation_query_stream_helper(explain_options=ExplainOptions(analyze=False))
 
 
+def test_aggregation__to_dict():
+    expected_alias = "alias"
+    expected_value = "value"
+    instance = AggregationResult(alias=expected_alias, value=expected_value)
dict_result = instance._to_dict() + assert len(dict_result) == 1 + assert next(iter(dict_result)) == expected_alias + assert dict_result[expected_alias] == expected_value + + def test_aggregation_from_query(): from google.cloud.firestore_v1 import _helpers @@ -952,3 +1023,147 @@ def test_aggregation_from_query(): metadata=client._rpc_metadata, **kwargs, ) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Sum + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Avg + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert 
len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_client() + parent = client.collection("dee") + query = make_query(parent) + aggregation_query = make_aggregation_query(query) + for _ in range(n): + aggregation_query.count() + pipeline = aggregation_query.pipeline() + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + + +def test_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, Pipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].alias == "alias" + assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].alias == "field_1" + assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].alias == "field_2" + assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 9140f53e8..fdd4a1450 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -31,6 +31,7 @@ from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator from google.cloud.firestore_v1.query_profile import ExplainMetrics, QueryExplainError from google.cloud.firestore_v1.query_results import QueryResultsList +from google.cloud.firestore_v1.field_path import FieldPath _PROJECT = "PROJECT" @@ -696,3 +697,147 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false(): explain_options = ExplainOptions(analyze=False) await _async_aggregation_query_stream_helper(explain_options=explain_options) + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): + from 
google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Sum + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "field,in_alias,out_alias", + [ + ("field", None, "field_1"), + (FieldPath("test"), None, "field_1"), + ("field", "overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Avg + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.avg(field, alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + expected_field = field if isinstance(field, str) else field.to_api_repr() + assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field + assert aggregate_stage.accumulators[0].alias == out_alias + + +@pytest.mark.parametrize( + "in_alias,out_alias", + [ + (None, "field_1"), + ("overwrite", "overwrite"), + ], +) +def test_async_aggreation_to_pipeline_count(in_alias, out_alias): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_expressions import Count + + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.count(alias=in_alias) + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 2 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/dee" + aggregate_stage = pipeline.stages[1] + assert isinstance(aggregate_stage, Aggregate) + assert len(aggregate_stage.accumulators) == 1 + assert isinstance(aggregate_stage.accumulators[0].expr, Count) + assert aggregate_stage.accumulators[0].alias == out_alias + + +def test_aggreation_to_pipeline_count_increment(): + """ + When aliases aren't given, should assign incrementing 
field_n values + """ + from google.cloud.firestore_v1.pipeline_expressions import Count + + n = 100 + client = make_async_client() + parent = client.collection("dee") + query = make_async_query(parent) + aggregation_query = make_async_aggregation_query(query) + for _ in range(n): + aggregation_query.count() + pipeline = aggregation_query.pipeline() + aggregate_stage = pipeline.stages[1] + assert len(aggregate_stage.accumulators) == n + for i in range(n): + assert isinstance(aggregate_stage.accumulators[i].expr, Count) + assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + + +def test_async_aggreation_to_pipeline_complex(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + + client = make_async_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + aggregation_query = make_async_aggregation_query(query) + aggregation_query.sum("field", alias="alias") + aggregation_query.count() + aggregation_query.avg("other") + aggregation_query.sum("final") + pipeline = aggregation_query.pipeline() + assert isinstance(pipeline, AsyncPipeline) + assert len(pipeline.stages) == 3 + assert isinstance(pipeline.stages[0], Collection) + assert isinstance(pipeline.stages[1], Select) + assert isinstance(pipeline.stages[2], Aggregate) + aggregate_stage = pipeline.stages[2] + assert len(aggregate_stage.accumulators) == 4 + assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].alias == "alias" + assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].alias == "field_1" + assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].alias == "field_2" + assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index a0194ace5..353997b8e 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -601,3 +601,23 @@ def test_asynccollectionreference_recursive(): col = _make_async_collection_reference("collection") assert isinstance(col.recursive(), AsyncQuery) + + +def test_asynccollectionreference_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1._pipeline_stages import Collection + + client = make_async_client() + collection = _make_async_collection_reference("collection", client=client) + pipeline = collection.pipeline() + assert isinstance(pipeline, AsyncPipeline) + # should have single "Collection" stage + assert len(pipeline.stages) == 1 + assert isinstance(pipeline.stages[0], Collection) + assert pipeline.stages[0].path == "/collection" + + +def test_asynccollectionreference_pipeline_no_client(): + collection = _make_async_collection_reference("collection") + with pytest.raises(ValueError, match="client"): + collection.pipeline() diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 54c80e5ad..dc5eb9e8a 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -909,3 +909,22 @@ async def test_asynccollectiongroup_get_partitions_w_offset(): query = _make_async_collection_group(parent).offset(10) with pytest.raises(ValueError): [i async for i 
+
+
+def test_asyncquery_collection_pipeline_type():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = make_async_client()
+    parent = client.collection("test")
+    query = parent._query()
+    ppl = query.pipeline()
+    assert isinstance(ppl, AsyncPipeline)
+
+
+def test_asyncquery_collectiongroup_pipeline_type():
+    from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
+
+    client = make_async_client()
+    query = client.collection_group("test")
+    ppl = query.pipeline()
+    assert isinstance(ppl, AsyncPipeline)
diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py
index 22baa0c5f..7f7be9c07 100644
--- a/tests/unit/v1/test_base_collection.py
+++ b/tests/unit/v1/test_base_collection.py
@@ -422,6 +422,20 @@ def test_basecollectionreference_end_at(mock_query):
     assert query == mock_query.end_at.return_value
 
 
+@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True)
+def test_basecollectionreference_pipeline(mock_query):
+    from google.cloud.firestore_v1.base_collection import BaseCollectionReference
+
+    with mock.patch.object(BaseCollectionReference, "_query") as _query:
+        _query.return_value = mock_query
+
+        collection = _make_base_collection_reference("collection")
+        pipeline = collection.pipeline()
+
+        mock_query.pipeline.assert_called_once_with()
+        assert pipeline == mock_query.pipeline.return_value
+
+
 @mock.patch("random.choice")
 def test__auto_id(mock_rand_choice):
     from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS, _auto_id
diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py
index 7804b0430..9bb3e61f8 100644
--- a/tests/unit/v1/test_base_query.py
+++ b/tests/unit/v1/test_base_query.py
@@ -18,6 +18,7 @@
 import pytest
 
 from tests.unit.v1._test_helpers import make_client
+from google.cloud.firestore_v1 import _pipeline_stages as stages
 
 
 def _make_base_query(*args, **kwargs):
@@ -1993,6 +1994,175 @@ def test__collection_group_query_response_to_snapshot_response():
     assert snapshot.update_time == response_pb._pb.document.update_time
 
 
+def test__query_pipeline_no_client():
+    mock_parent = mock.Mock()
+    mock_parent._client = None
+    query = _make_base_query(mock_parent)
+    with pytest.raises(ValueError, match="client"):
+        query.pipeline()
+
+
+def test__query_pipeline_descendants():
+    client = make_client()
+    query = client.collection_group("my_col")
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 1
+    stage = pipeline.stages[0]
+    assert isinstance(stage, stages.CollectionGroup)
+    assert stage.collection_id == "my_col"
+
+
+@pytest.mark.parametrize(
+    "in_path,out_path",
+    [
+        ("my_col/doc/", "/my_col/doc/"),
+        ("/my_col/doc", "/my_col/doc"),
+        ("my_col/doc/sub_col", "/my_col/doc/sub_col"),
+    ],
+)
+def test__query_pipeline_no_descendants(in_path, out_path):
+    client = make_client()
+    collection = client.collection(in_path)
+    query = collection._query()
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 1
+    stage = pipeline.stages[0]
+    assert isinstance(stage, stages.Collection)
+    assert stage.path == out_path
+
+
+def test__query_pipeline_composite_filter():
+    from google.cloud.firestore_v1 import FieldFilter
+    from google.cloud.firestore_v1 import pipeline_expressions as expr
+
+    client = make_client()
+    in_filter = FieldFilter("field_a", "==", "value_a")
+    query = client.collection("my_col").where(filter=in_filter)
+    with mock.patch.object(
+        expr.FilterCondition, "_from_query_filter_pb"
+    ) as convert_mock:
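+        # mocking _from_query_filter_pb lets the test assert below that the
+        # query's raw filter proto (and the client) reach the converter intact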
+ pipeline = query.pipeline() + convert_mock.assert_called_once_with(in_filter._to_pb(), client) + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Where) + assert stage.condition == convert_mock.return_value + + +def test__query_pipeline_projections(): + client = make_client() + query = client.collection("my_col").select(["field_a", "field_b.c"]) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 2 + stage = pipeline.stages[1] + assert isinstance(stage, stages.Select) + assert len(stage.projections) == 2 + assert stage.projections[0].path == "field_a" + assert stage.projections[1].path == "field_b.c" + + +def test__query_pipeline_order_exists_multiple(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query = client.collection("my_col").order_by("field_a").order_by("field_b") + pipeline = query.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline.stages) == 3 + where_stage = pipeline.stages[1] + assert isinstance(where_stage, stages.Where) + # should have and with both orderings + assert isinstance(where_stage.condition, expr.And) + assert len(where_stage.condition.params) == 2 + operands = [p for p in where_stage.condition.params] + assert isinstance(operands[0], expr.Exists) + assert operands[0].params[0].path == "field_a" + assert isinstance(operands[1], expr.Exists) + assert operands[1].params[0].path == "field_b" + + +def test__query_pipeline_order_exists_single(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + + client = make_client() + query_single = client.collection("my_col").order_by("field_c") + pipeline_single = query_single.pipeline() + + # should have collection, where, and sort + # we're interested in where + assert len(pipeline_single.stages) == 3 + where_stage_single = pipeline_single.stages[1] + assert isinstance(where_stage_single, stages.Where) + assert isinstance(where_stage_single.condition, expr.Exists) + assert where_stage_single.condition.params[0].path == "field_c" + + +def test__query_pipeline_order_sorts(): + from google.cloud.firestore_v1 import pipeline_expressions as expr + from google.cloud.firestore_v1.base_query import BaseQuery + + client = make_client() + query = ( + client.collection("my_col") + .order_by("field_a", direction=BaseQuery.ASCENDING) + .order_by("field_b", direction=BaseQuery.DESCENDING) + ) + pipeline = query.pipeline() + + assert len(pipeline.stages) == 3 + sort_stage = pipeline.stages[2] + assert isinstance(sort_stage, stages.Sort) + assert len(sort_stage.orders) == 2 + assert isinstance(sort_stage.orders[0], expr.Ordering) + assert sort_stage.orders[0].expr.path == "field_a" + assert sort_stage.orders[0].order_dir == expr.Ordering.Direction.ASCENDING + assert isinstance(sort_stage.orders[1], expr.Ordering) + assert sort_stage.orders[1].expr.path == "field_b" + assert sort_stage.orders[1].order_dir == expr.Ordering.Direction.DESCENDING + + +def test__query_pipeline_unsupported(): + client = make_client() + query_start = client.collection("my_col").start_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_start.pipeline() + + query_end = client.collection("my_col").end_at({"field_a": "value"}) + with pytest.raises(NotImplementedError, match="cursors"): + query_end.pipeline() + + query_limit_last = client.collection("my_col").limit_to_last(10) + with pytest.raises(NotImplementedError, match="limit_to_last"): + 
query_limit_last.pipeline()
+
+
+def test__query_pipeline_limit():
+    client = make_client()
+    query = client.collection("my_col").limit(15)
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 2
+    stage = pipeline.stages[1]
+    assert isinstance(stage, stages.Limit)
+    assert stage.limit == 15
+
+
+def test__query_pipeline_offset():
+    client = make_client()
+    query = client.collection("my_col").offset(5)
+    pipeline = query.pipeline()
+
+    assert len(pipeline.stages) == 2
+    stage = pipeline.stages[1]
+    assert isinstance(stage, stages.Offset)
+    assert stage.offset == 5
+
+
 def _make_order_pb(field_path, direction):
     from google.cloud.firestore_v1.types import query
 
diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py
index da91651b9..9e615541a 100644
--- a/tests/unit/v1/test_collection.py
+++ b/tests/unit/v1/test_collection.py
@@ -15,6 +15,7 @@
 import types
 
 import mock
+import pytest
 from datetime import datetime, timezone
 
 from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT
@@ -510,6 +511,27 @@ def test_stream_w_read_time(query_class):
     )
 
 
+def test_collectionreference_pipeline():
+    from tests.unit.v1 import _test_helpers
+    from google.cloud.firestore_v1.pipeline import Pipeline
+    from google.cloud.firestore_v1._pipeline_stages import Collection
+
+    client = _test_helpers.make_client()
+    collection = _make_collection_reference("collection", client=client)
+    pipeline = collection.pipeline()
+    assert isinstance(pipeline, Pipeline)
+    # should have single "Collection" stage
+    assert len(pipeline.stages) == 1
+    assert isinstance(pipeline.stages[0], Collection)
+    assert pipeline.stages[0].path == "/collection"
+
+
+def test_collectionreference_pipeline_no_client():
+    collection = _make_collection_reference("collection")
+    with pytest.raises(ValueError, match="client"):
+        collection.pipeline()
+
+
 @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True)
 def test_on_snapshot(watch):
     collection = _make_collection_reference("collection")
diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py
index c5329df33..9f06c47b8 100644
--- a/tests/unit/v1/test_pipeline_expressions.py
+++ b/tests/unit/v1/test_pipeline_expressions.py
@@ -834,6 +834,15 @@ def test_is_nan(self):
         infix_instance = arg1.is_nan()
         assert infix_instance == instance
 
+    def test_is_null(self):
+        arg1 = self._make_arg("Value")
+        instance = Expr.is_null(arg1)
+        assert instance.name == "is_null"
+        assert instance.params == [arg1]
+        assert repr(instance) == "Value.is_null()"
+        infix_instance = arg1.is_null()
+        assert infix_instance == instance
+
     def test_not(self):
         arg1 = self._make_arg("Condition")
         instance = expr.Not(arg1)
@@ -1179,6 +1188,12 @@ def test_count(self):
         infix_instance = arg1.count()
         assert infix_instance == instance
 
+    def test_base_count(self):
+        instance = expr.Count()
+        assert instance.name == "count"
+        assert instance.params == []
+        assert repr(instance) == "Count()"
+
     def test_minimum(self):
         arg1 = self._make_arg("Value")
         instance = Expr.minimum(arg1)
diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py
index d5b36e56c..fadea7e12 100644
--- a/tests/unit/v1/test_pipeline_stages.py
+++ b/tests/unit/v1/test_pipeline_stages.py
@@ -185,8 +185,9 @@ def test_to_pb(self):
         instance = self._make_one(input_arg)
         result = instance._to_pb()
         assert result.name == "collection_group"
-        assert len(result.args) == 1
-        assert result.args[0].string_value == "test"
+        assert len(result.args) == 2
+        assert result.args[0].reference_value == ""
+        assert result.args[1].string_value == "test"
         assert len(result.options) == 0
 
diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py
index b8c37cf84..8b1217370 100644
--- a/tests/unit/v1/test_query.py
+++ b/tests/unit/v1/test_query.py
@@ -1046,3 +1046,22 @@ def test_collection_group_get_partitions_w_offset(database):
     query = _make_collection_group(parent).offset(10)
     with pytest.raises(ValueError):
         list(query.get_partitions(2))
+
+
+def test_query_collection_pipeline_type():
+    from google.cloud.firestore_v1.pipeline import Pipeline
+
+    client = make_client()
+    parent = client.collection("test")
+    query = parent._query()
+    ppl = query.pipeline()
+    assert isinstance(ppl, Pipeline)
+
+
+def test_query_collectiongroup_pipeline_type():
+    from google.cloud.firestore_v1.pipeline import Pipeline
+
+    client = make_client()
+    query = client.collection_group("test")
+    ppl = query.pipeline()
+    assert isinstance(ppl, Pipeline)

From 643f01488f0b818bd14102d608237fa65ff12e62 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 29 Oct 2025 14:05:46 -0700
Subject: [PATCH 06/27] feat: Additional Pipeline Expressions (#1115)

* fixed tests

* added vector expressions

* added new math expressions

* added string manipulation expressions

* added not_nan, not_null, and is_absent

* added new Array type

* added map and related expressions

* remove dict and list from constant types

* Fixed lint

* added count_if and count_distinct

* added misc expressions

* added error functions

* fixed lint

* fixed typos
---
 google/cloud/firestore_v1/_pipeline_stages.py |    8 +-
 .../firestore_v1/pipeline_expressions.py      |  674 +++++++-
 tests/system/pipeline_e2e.yaml                | 1483 ++++++++++++++++-
 tests/system/test_pipeline_acceptance.py      |   45 +-
 tests/system/test_system.py                   |    9 +-
 tests/system/test_system_async.py             |    5 +-
 tests/unit/v1/test_aggregation.py             |    2 +-
 tests/unit/v1/test_async_aggregation.py       |    2 +-
 tests/unit/v1/test_pipeline_expressions.py    |  526 +++++-
 9 files changed, 2546 insertions(+), 208 deletions(-)

diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py
index 7233a8eec..c63b748ac 100644
--- a/google/cloud/firestore_v1/_pipeline_stages.py
+++ b/google/cloud/firestore_v1/_pipeline_stages.py
@@ -274,13 +274,17 @@ def __init__(
         self,
         field: str | Expr,
         vector: Sequence[float] | Vector,
-        distance_measure: "DistanceMeasure",
+        distance_measure: "DistanceMeasure" | str,
         options: Optional["FindNearestOptions"] = None,
     ):
         super().__init__("find_nearest")
         self.field: Expr = Field(field) if isinstance(field, str) else field
         self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector)
-        self.distance_measure = distance_measure
+        self.distance_measure = (
+            distance_measure
+            if isinstance(distance_measure, DistanceMeasure)
+            else DistanceMeasure[distance_measure.upper()]
+        )
         self.options = options or FindNearestOptions()
 
     def _pb_args(self):
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 4639e0f7d..b113e2874 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -17,7 +17,6 @@
     Any,
     Generic,
     TypeVar,
-    Dict,
     Sequence,
 )
 from abc import ABC
@@ -41,8 +40,6 @@
     bytes,
     GeoPoint,
     Vector,
-    list,
-    Dict[str, Any],
     None,
 )
@@ -113,8 +110,20 @@ def _to_pb(self) -> Value:
         raise NotImplementedError
 
     @staticmethod
-    def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr":
-        return o if isinstance(o, Expr) else Constant(o)
+    def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False) -> "Expr":
+        """Convert arbitrary object to an Expr."""
+        if isinstance(o, Constant) and isinstance(o.value, list):
+            o = o.value
+        if isinstance(o, Expr):
+            return o
+        if isinstance(o, dict):
+            return Map(o)
+        if isinstance(o, list):
+            if include_vector and all([isinstance(i, (float, int)) for i in o]):
+                return Constant(Vector(o))
+            else:
+                return Array(o)
+        return Constant(o)
 
 
 class expose_as_static:
     """
@@ -132,6 +141,10 @@ def __init__(self, instance_func):
         self.instance_func = instance_func
 
     def static_func(self, first_arg, *other_args, **kwargs):
+        if not isinstance(first_arg, (Expr, str)):
+            raise TypeError(
+                f"Expressions must be called on an Expr or a string representing a field name. Got {type(first_arg)}."
+            )
         first_expr = (
             Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg
         )
@@ -239,6 +252,147 @@ def mod(self, other: Expr | float) -> "Expr":
         """
         return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)])
 
+    @expose_as_static
+    def abs(self) -> "Expr":
+        """Creates an expression that calculates the absolute value of this expression.
+
+        Example:
+            >>> # Get the absolute value of the 'change' field.
+            >>> Field.of("change").abs()
+
+        Returns:
+            A new `Expr` representing the absolute value.
+        """
+        return Function("abs", [self])
+
+    @expose_as_static
+    def ceil(self) -> "Expr":
+        """Creates an expression that calculates the ceiling of this expression.
+
+        Example:
+            >>> # Get the ceiling of the 'value' field.
+            >>> Field.of("value").ceil()
+
+        Returns:
+            A new `Expr` representing the ceiling value.
+        """
+        return Function("ceil", [self])
+
+    @expose_as_static
+    def exp(self) -> "Expr":
+        """Creates an expression that computes e to the power of this expression.
+
+        Example:
+            >>> # Compute e to the power of the 'value' field
+            >>> Field.of("value").exp()
+
+        Returns:
+            A new `Expr` representing the exponential value.
+        """
+        return Function("exp", [self])
+
+    @expose_as_static
+    def floor(self) -> "Expr":
+        """Creates an expression that calculates the floor of this expression.
+
+        Example:
+            >>> # Get the floor of the 'value' field.
+            >>> Field.of("value").floor()
+
+        Returns:
+            A new `Expr` representing the floor value.
+        """
+        return Function("floor", [self])
+
+    @expose_as_static
+    def ln(self) -> "Expr":
+        """Creates an expression that calculates the natural logarithm of this expression.
+
+        Example:
+            >>> # Get the natural logarithm of the 'value' field.
+            >>> Field.of("value").ln()
+
+        Returns:
+            A new `Expr` representing the natural logarithm.
+        """
+        return Function("ln", [self])
+
+    @expose_as_static
+    def log(self, base: Expr | float) -> "Expr":
+        """Creates an expression that calculates the logarithm of this expression with a given base.
+
+        Example:
+            >>> # Get the logarithm of 'value' with base 2.
+            >>> Field.of("value").log(2)
+            >>> # Get the logarithm of 'value' with base from 'base_field'.
+            >>> Field.of("value").log(Field.of("base_field"))
+
+        Args:
+            base: The base of the logarithm.
+
+        Returns:
+            A new `Expr` representing the logarithm.
+        """
+        return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)])
+
+    @expose_as_static
+    def log10(self) -> "Expr":
+        """Creates an expression that calculates the base 10 logarithm of this expression.
+
+        Example:
+            >>> Field.of("value").log10()
+
+        Returns:
+            A new `Expr` representing the logarithm.
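+
+        Note:
+            this should be mathematically equivalent to ``log`` with an
+            explicit base of 10 (a sketch, using this module's own names):
+
+            >>> Field.of("value").log10()
+            >>> Field.of("value").log(10)  # expected to match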
+ """ + return Function("log10", [self]) + + @expose_as_static + def pow(self, exponent: Expr | float) -> "Expr": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. + """ + return Function( + "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] + ) + + @expose_as_static + def round(self) -> "Expr": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expr` representing the rounded value. + """ + return Function("round", [self]) + + @expose_as_static + def sqrt(self) -> "Expr": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expr` representing the square root. + """ + return Function("sqrt", [self]) + @expose_as_static def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression @@ -420,7 +574,9 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": ) @expose_as_static - def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def equal_any( + self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -438,14 +594,14 @@ def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array), ], ) @expose_as_static - def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def not_equal_any( + self, array: Array | list[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -463,9 +619,7 @@ def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "not_equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array), ], ) @@ -492,7 +646,7 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": @expose_as_static def array_contains_all( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. @@ -512,16 +666,14 @@ def array_contains_all( "array_contains_all", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ), + self._cast_to_expr_or_convert_to_constant(elements), ], ) @expose_as_static def array_contains_any( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. 
@@ -542,9 +694,7 @@ def array_contains_any(
             "array_contains_any",
             [
                 self,
-                _ListOfExprs(
-                    [self._cast_to_expr_or_convert_to_constant(e) for e in elements]
-                ),
+                self._cast_to_expr_or_convert_to_constant(elements),
             ],
         )
 
@@ -574,6 +724,90 @@ def array_reverse(self) -> "Expr":
         """
         return Function("array_reverse", [self])
 
+    @expose_as_static
+    def array_concat(
+        self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr
+    ) -> "Expr":
+        """Creates an expression that concatenates an array expression with another array.
+
+        Example:
+            >>> # Combine the 'tags' array with a new array and an array field
+            >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")])
+
+        Args:
+            *other_arrays: The arrays (lists, `Array` expressions, or array-valued expressions) to concatenate.
+
+        Returns:
+            A new `Expr` representing the concatenated array.
+        """
+        return Function(
+            "array_concat",
+            [self]
+            + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays],
+        )
+
+    @expose_as_static
+    def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr":
+        """Creates an expression that concatenates expressions together.
+
+        Args:
+            *others: The expressions to concatenate.
+
+        Returns:
+            A new `Expr` representing the concatenated value.
+        """
+        return Function(
+            "concat",
+            [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others],
+        )
+
+    @expose_as_static
+    def length(self) -> "Expr":
+        """
+        Creates an expression that calculates the length of the expression if it is a string, array, map, or blob.
+
+        Example:
+            >>> # Get the length of the 'name' field.
+            >>> Field.of("name").length()
+
+        Returns:
+            A new `Expr` representing the length of the expression.
+        """
+        return Function("length", [self])
+
+    @expose_as_static
+    def is_absent(self) -> "BooleanExpr":
+        """Creates an expression that returns true if a value is absent. Otherwise, returns false even if
+        the value is null.
+
+        Example:
+            >>> # Check if the 'email' field is absent.
+            >>> Field.of("email").is_absent()
+
+        Returns:
+            A new `BooleanExpr` representing the isAbsent operation.
+        """
+        return BooleanExpr("is_absent", [self])
+
+    @expose_as_static
+    def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr":
+        """Creates an expression that returns a default value if an expression evaluates to an absent value.
+
+        Example:
+            >>> # Return the value of the 'email' field, or "N/A" if it's absent.
+            >>> Field.of("email").if_absent("N/A")
+
+        Args:
+            default_value: The expression or constant value to return if this expression is absent.
+
+        Returns:
+            A new `Expr` representing the ifAbsent operation.
+        """
+        return Function(
+            "if_absent",
+            [self, self._cast_to_expr_or_convert_to_constant(default_value)],
+        )
+
     @expose_as_static
     def is_nan(self) -> "BooleanExpr":
         """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number).
@@ -587,9 +821,22 @@ def is_nan(self) -> "BooleanExpr":
         """
         return BooleanExpr("is_nan", [self])
 
+    @expose_as_static
+    def is_not_nan(self) -> "BooleanExpr":
+        """Creates an expression that checks if this expression evaluates to a non-'NaN' (Not a Number) value.
+
+        Example:
+            >>> # Check if the result of a calculation is not NaN
+            >>> Field.of("value").divide(1).is_not_nan()
+
+        Returns:
+            A new `Expr` representing the 'is not NaN' check.
+        """
+        return BooleanExpr("is_not_nan", [self])
+
     @expose_as_static
     def is_null(self) -> "BooleanExpr":
-        """Creates an expression that checks if this expression evaluates to 'Null'.
+ """Creates an expression that checks if the value of a field is 'Null'. Example: >>> Field.of("value").is_null() @@ -599,6 +846,50 @@ def is_null(self) -> "BooleanExpr": """ return BooleanExpr("is_null", [self]) + @expose_as_static + def is_not_null(self) -> "BooleanExpr": + """Creates an expression that checks if the value of a field is not 'Null'. + + Example: + >>> Field.of("value").is_not_null() + + Returns: + A new `Expr` representing the 'isNotNull' check. + """ + return BooleanExpr("is_not_null", [self]) + + @expose_as_static + def is_error(self): + """Creates an expression that checks if a given expression produces an error + + Example: + >>> # Resolves to True if an expression produces an error + >>> Field.of("value").divide("string").is_error() + + Returns: + A new `Expr` representing the isError operation. + """ + return Function("is_error", [self]) + + @expose_as_static + def if_error(self, then_value: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that returns ``then_value`` if this expression evaluates to an error. + Otherwise, returns the value of this expression. + + Example: + >>> # Resolves to 0 if an expression produces an error + >>> Field.of("value").divide("string").if_error(0) + + Args: + then_value: The value to return if this expression evaluates to an error. + + Returns: + A new `Expr` representing the ifError operation. + """ + return Function( + "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)] + ) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. @@ -653,6 +944,35 @@ def count(self) -> "Expr": """ return AggregateFunction("count", [self]) + @expose_as_static + def count_if(self) -> "Expr": + """Creates an aggregation that counts the number of values of the provided field or expression + that evaluate to True. + + Example: + >>> # Count the number of adults + >>> Field.of("age").greater_than(18).count_if().as_("totalAdults") + + + Returns: + A new `AggregateFunction` representing the 'count_if' aggregation. + """ + return AggregateFunction("count_if", [self]) + + @expose_as_static + def count_distinct(self) -> "Expr": + """Creates an aggregation that counts the number of distinct values of the + provided field or expression. + + Example: + >>> # Count the total number of countries in the data + >>> Field.of("country").count_distinct().as_("totalCountries") + + Returns: + A new `AggregateFunction` representing the 'count_distinct' aggregation. + """ + return AggregateFunction("count_distinct", [self]) + @expose_as_static def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. @@ -846,12 +1166,106 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) + @expose_as_static + def to_lower(self) -> "Expr": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ + return Function("to_lower", [self]) + + @expose_as_static + def to_upper(self) -> "Expr": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. 
+ """ + return Function("to_upper", [self]) + + @expose_as_static + def trim(self) -> "Expr": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ + return Function("trim", [self]) + + @expose_as_static + def string_reverse(self) -> "Expr": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expr` representing the reversed string. + """ + return Function("string_reverse", [self]) + + @expose_as_static + def substring( + self, position: Expr | int, length: Expr | int | None = None + ) -> "Expr": + """Creates an expression that returns a substring of the results of this expression. + + + Example: + >>> Field.of("description").substring(5, 10) + >>> Field.of("description").substring(5) + + Args: + position: the index of the first character of the substring. + length: the length of the substring. If not provided the substring + will end at the end of the input. + + Returns: + A new `Expr` representing the extracted substring. + """ + args = [self, self._cast_to_expr_or_convert_to_constant(position)] + if length is not None: + args.append(self._cast_to_expr_or_convert_to_constant(length)) + return Function("substring", args) + + @expose_as_static + def join(self, delimeter: Expr | str) -> "Expr": + """Creates an expression that joins the elements of an array into a string + + + Example: + >>> Field.of("tags").join(", ") + + Args: + delimiter: The delimiter to add between the elements of the array. + + Returns: + A new `Expr` representing the joined string. + """ + return Function( + "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] + ) + @expose_as_static def map_get(self, key: str | Constant[str]) -> "Expr": """Accesses a value from the map produced by evaluating this expression. Example: - >>> Expr.map({"city": "London"}).map_get("city") + >>> Map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -861,7 +1275,118 @@ def map_get(self, key: str | Constant[str]) -> "Expr": A new `Expr` representing the value associated with the given key in the map. """ return Function( - "map_get", [self, Constant.of(key) if isinstance(key, str) else key] + "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) + + @expose_as_static + def map_remove(self, key: str | Constant[str]) -> "Expr": + """Remove a key from the map produced by evaluating this expression. + + Example: + >>> Map({"city": "London"}).map_remove("city") + >>> Field.of("address").map_remove("city") + + Args: + key: The key to remove in the map. + + Returns: + A new `Expr` representing the map_remove operation. + """ + return Function( + "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) + + @expose_as_static + def map_merge( + self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr + ) -> "Expr": + """Creates an expression that merges one or more dicts into a single map. + + Example: + >>> Map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) + >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) + + Args: + *other_maps: Sequence of maps to merge into the resulting map. 
+ + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + return Function( + "map_merge", + [self] + [self._cast_to_expr_or_convert_to_constant(m) for m in other_maps], + ) + + @expose_as_static + def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ + return Function( + "cosine_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. + """ + return Function( + "euclidean_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ + return Function( + "dot_product", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], ) @expose_as_static @@ -1034,6 +1559,19 @@ def collection_id(self): """ return Function("collection_id", [self]) + @expose_as_static + def document_id(self): + """Creates an expression that returns the document ID from a path. + + Example: + >>> # Get the document ID from a path. + >>> Field.of("__name__").document_id() + + Returns: + A new `Expr` representing the document ID. + """ + return Function("document_id", [self]) + def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. 
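All three vector distance helpers accept the same operand shapes via `_cast_to_expr_or_convert_to_constant(..., include_vector=True)`; a sketch of the equivalent call forms (names from this module, with an assumed numeric query vector):

    target = [0.5, 0.8, 0.2]
    Field.of("embedding").cosine_distance(target)             # plain list of floats
    Field.of("embedding").euclidean_distance(Vector(target))  # explicit Vector
    Field.of("embedding").dot_product(Field.of("other_vec"))  # another expression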
@@ -1107,25 +1645,6 @@ def _to_pb(self) -> Value: return encode_value(self.value) -class _ListOfExprs(Expr): - """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - - def __init__(self, exprs: Sequence[Expr]): - self.exprs: list[Expr] = list(exprs) - - def __eq__(self, other): - if not isinstance(other, _ListOfExprs): - return False - else: - return other.exprs == self.exprs - - def __repr__(self): - return repr(self.exprs) - - def _to_pb(self): - return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) - - class Function(Expr): """A base class for expressions that represent function calls.""" @@ -1323,11 +1842,11 @@ def _from_query_filter_pb(filter_pb, client): if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: return And(field.exists(), field.is_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: - return And(field.exists(), Not(field.is_nan())) + return And(field.exists(), field.is_not_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: return And(field.exists(), field.is_null()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.is_null())) + return And(field.exists(), field.is_not_null()) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): @@ -1367,6 +1886,55 @@ def _from_query_filter_pb(filter_pb, client): raise TypeError(f"Unexpected filter type: {type(filter_pb)}") +class Array(Function): + """ + Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Expr.array(["bar", Field.of("baz")]) + + Args: + elements: The input list to evaluate in the expression + """ + + def __init__(self, elements: list[Expr | CONSTANT_TYPE]): + if not isinstance(elements, list): + raise TypeError("Array must be constructed with a list") + converted_elements = [ + self._cast_to_expr_or_convert_to_constant(el) for el in elements + ] + super().__init__("array", converted_elements) + + def __repr__(self): + return f"Array({self.params})" + + +class Map(Function): + """ + Creates an expression that creates a Firestore map value from an input dict. + + Example: + >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: The input dict to evaluate in the expression + """ + + def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]): + element_list = [] + for k, v in elements.items(): + element_list.append(self._cast_to_expr_or_convert_to_constant(k)) + element_list.append(self._cast_to_expr_or_convert_to_constant(v)) + super().__init__("map", element_list) + + def __repr__(self): + formatted_params = [ + a.value if isinstance(a, Constant) else a for a in self.params + ] + d = {a: b for a, b in zip(formatted_params[::2], formatted_params[1::2])} + return f"Map({d})" + + class And(BooleanExpr): """ Represents an expression that performs a logical 'AND' operation on multiple filter conditions. 
@@ -1454,6 +2022,7 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) + class Count(AggregateFunction): """ Represents an aggregation that counts the number of stage inputs with valid evaluations of the @@ -1471,6 +2040,15 @@ class Count(AggregateFunction): def __init__(self, expression: Expr | None = None): expression_list = [expression] if expression else [] - super().__init__( - "count", expression_list, use_infix_repr=bool(expression_list) - ) + super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) + + +class CurrentTimestamp(Function): + """Creates an expression that returns the current timestamp + + Returns: + A new `Expr` representing the current timestamp. + """ + + def __init__(self): + super().__init__("current_timestamp", [], use_infix_repr=False) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 50cc7c29d..38595224a 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -136,6 +136,10 @@ data: embedding: [1.0, 2.0, 3.0] vec2: embedding: [4.0, 5.0, 6.0, 7.0] + vec3: + embedding: [5.0, 6.0, 7.0] + vec4: + embedding: [1.0, 2.0, 4.0] tests: - description: "testAggregates - count" pipeline: @@ -163,6 +167,64 @@ tests: - fieldReferenceValue: rating - mapValue: {} name: aggregate + - description: "testAggregates - count_if" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_if: + - Expr.greater_than: + - Field: rating + - Constant: 4.2 + - "count_if_rating_gt_4_2" + assert_results: + - count_if_rating_gt_4_2: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count_if_rating_gt_4_2: + functionValue: + name: count_if + args: + - functionValue: + name: greater_than + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + - mapValue: {} + name: aggregate + - description: "testAggregates - count_distinct" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_distinct: + - Field: genre + - "distinct_genres" + assert_results: + - distinct_genres: 8 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + distinct_genres: + functionValue: + name: count_distinct + args: + - fieldReferenceValue: genre + - mapValue: {} + name: aggregate - description: "testAggregates - avg, count, max" pipeline: - Collection: books @@ -697,10 +759,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: comedy - stringValue: classic + name: array name: array_contains_any name: where - args: @@ -739,10 +802,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: adventure - stringValue: magic + name: array name: array_contains_all name: where - args: @@ -929,7 +993,57 @@ tests: expression: fieldReferenceValue: title name: sort + - description: testConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.concat: + - Field: author + - Constant: ": " + - Field: title + - "author_title" + - AliasedExpr: + - Expr.concat: + - Field: tags + - - Constant: "new_tag" + - "concatenatedTags" + assert_results: + - author_title: "Douglas Adams: The 
Hitchhiker's Guide to the Galaxy" + concatenatedTags: + - comedy + - space + - adventure + - new_tag - description: testLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.length: + - Field: title + - "titleLength" + - AliasedExpr: + - Expr.length: + - Field: tags + - "tagsLength" + - AliasedExpr: + - Expr.length: + - Field: awards + - "awardsLength" + assert_results: + - titleLength: 36 + tagsLength: 3 + awardsLength: 2 + - description: testCharLength pipeline: - Collection: books - Select: @@ -1414,6 +1528,177 @@ tests: - args: - integerValue: '1' name: limit + - description: testIsNotNull + pipeline: + - Collection: books + - Where: + - Expr.is_not_null: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_null + name: where + - description: testIsNotNaN + pipeline: + - Collection: books + - Where: + - Expr.is_not_nan: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_nan + name: where + - description: testIsAbsent + pipeline: + - Collection: books + - Where: + - Expr.is_absent: + - Field: awards.pulitzer + assert_count: 9 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.pulitzer + name: is_absent + name: where + - description: testIfAbsent + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_absent: + - Field: awards.pulitzer + - Constant: false + - "pulitzer_award" + - title + - Where: + - Expr.equal: + - Field: pulitzer_award + - Constant: true + assert_results: + - pulitzer_award: true + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pulitzer_award: + functionValue: + name: if_absent + args: + - fieldReferenceValue: awards.pulitzer + - booleanValue: false + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: pulitzer_award + - booleanValue: true + name: equal + name: where + - description: testIsError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.is_error: + - Expr.divide: + - Field: rating + - Constant: "string" + - "is_error_result" + - Limit: 1 + assert_results: + - is_error_result: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + is_error_result: + functionValue: + name: is_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - stringValue: "string" + name: select + - args: + - integerValue: '1' + name: limit + - description: testIfError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_error: + - Expr.divide: + - Field: rating + - Field: genre + - Constant: "An error occurred" + - "if_error_result" + - Limit: 1 + assert_results: + - if_error_result: "An error occurred" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + if_error_result: + functionValue: + name: if_error + args: + - 
functionValue: + name: divide + args: + - fieldReferenceValue: rating + - fieldReferenceValue: genre + - stringValue: "An error occurred" + name: select + - args: + - integerValue: '1' + name: limit - description: testLogicalMinMax pipeline: - Collection: books @@ -1520,25 +1805,27 @@ tests: - booleanValue: true name: equal name: where - - description: testNestedFields + - description: testMapGetWithField pipeline: - Collection: books - Where: - Expr.equal: - - Field: awards.hugo - - Constant: true - - Sort: - - Ordering: - - Field: title - - DESCENDING + - Field: title + - Constant: "Dune" + - AddFields: + - AliasedExpr: + - Constant: "hugo" + - "award_name" - Select: - - title - - Field: awards.hugo + - AliasedExpr: + - Expr.map_get: + - Field: awards + - Field: award_name + - "hugoAward" + - Field: title assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true + - hugoAward: true + title: Dune assert_proto: pipeline: stages: @@ -1548,31 +1835,44 @@ tests: - args: - functionValue: args: - - fieldReferenceValue: awards.hugo - - booleanValue: true + - fieldReferenceValue: title + - stringValue: "Dune" name: equal name: where - args: - mapValue: fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort + award_name: + stringValue: "hugo" + name: add_fields - args: - mapValue: fields: - awards.hugo: - fieldReferenceValue: awards.hugo + hugoAward: + functionValue: + name: map_get + args: + - fieldReferenceValue: awards + - fieldReferenceValue: award_name title: fieldReferenceValue: title name: select - - description: testSampleLimit + - description: testMapRemove pipeline: - Collection: books - - Sample: 3 - assert_count: 3 # Results will vary due to randomness + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_remove: + - Field: awards + - "nebula" + - "awards_removed" + assert_results: + - awards_removed: + hugo: true assert_proto: pipeline: stages: @@ -1580,14 +1880,146 @@ tests: - referenceValue: /books name: collection - args: - - integerValue: '3' - - stringValue: documents - name: sample - - description: testSamplePercentage - pipeline: - - Collection: books - - Sample: - - SampleOptions: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_removed: + functionValue: + name: map_remove + args: + - fieldReferenceValue: awards + - stringValue: "nebula" + name: select + - description: testMapMerge + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_merge: + - Field: awards + - Map: + elements: {"new_award": true, "hugo": false} + - Map: + elements: {"another_award": "yes"} + - "awards_merged" + assert_results: + - awards_merged: + hugo: false + nebula: true + new_award: true + another_award: "yes" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_merged: + functionValue: + name: map_merge + args: + - fieldReferenceValue: awards + - functionValue: + name: map + args: + - stringValue: "new_award" + - booleanValue: true + - stringValue: "hugo" + - booleanValue: false + - functionValue: + name: map + args: + - stringValue: "another_award" + - 
stringValue: "yes" + name: select + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: equal + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: - 0.6 - percent assert_proto: @@ -1730,59 +2162,379 @@ tests: - adventure - space - comedy - - description: testExists + - description: testDocumentId pipeline: - Collection: books - Where: - - And: - - Expr.exists: - - Field: awards.pulitzer - - Expr.equal: - - Field: awards.pulitzer - - Constant: true + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - title + - AliasedExpr: + - Expr.document_id: + - Field: __name__ + - "doc_id" assert_results: - - title: To Kill a Mockingbird - - description: testSum + - doc_id: "book1" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + doc_id: + functionValue: + name: document_id + args: + - fieldReferenceValue: __name__ + name: select + - description: testCurrentTimestamp + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - And: + - Expr.greater_than_or_equal: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 1735689600 # 2025-01-01 + - Expr.less_than: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 4892438400 # 2125-01-01 + - "is_between_2025_and_2125" + assert_results: + - is_between_2025_and_2125: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + is_between_2025_and_2125: + functionValue: + name: and + args: + - functionValue: + name: greater_than_or_equal + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '1735689600' + - functionValue: + name: less_than + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '4892438400' + name: select + - description: testArrayConcat pipeline: - Collection: books - Where: - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Aggregate: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + 
- Select: - AliasedExpr: - - Expr.sum: - - Field: rating - - "total_rating" + - Expr.array_concat: + - Field: tags + - Constant: ["new_tag", "another_tag"] + - "concatenatedTags" assert_results: - - total_rating: 8.8 - - description: testStringContains + - concatenatedTags: + - comedy + - space + - adventure + - new_tag + - another_tag + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "new_tag" + - stringValue: "another_tag" + name: array + name: array_concat + name: select + - description: testArrayConcatMultiple pipeline: - Collection: books - Where: - - Expr.string_contains: + - Expr.equal: - Field: title - - Constant: "Hitchhiker's" + - Constant: "Dune" - Select: - - title + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Constant: ["sci-fi"] + - Constant: ["classic", "epic"] + - "concatenatedTags" assert_results: - - title: "The Hitchhiker's Guide to the Galaxy" - - description: testVectorLength + - concatenatedTags: + - politics + - desert + - ecology + - sci-fi + - classic + - epic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "sci-fi" + name: array + - functionValue: + args: + - stringValue: "classic" + - stringValue: "epic" + name: array + name: array_concat + name: select + - description: testMapMergeLiterals pipeline: - - Collection: vectors + - Collection: books + - Limit: 1 - Select: - AliasedExpr: - - Expr.vector_length: - - Field: embedding - - "embedding_length" - - Sort: - - Ordering: - - Field: embedding_length - - ASCENDING + - Expr.map_merge: + - Map: + elements: {"a": "orig", "b": "orig"} + - Map: + elements: {"b": "new", "c": "new"} + - "merged" assert_results: - - embedding_length: 3 + - merged: + a: "orig" + b: "new" + c: "new" + - description: testArrayContainsAnyWithField + pipeline: + - Collection: books + - AddFields: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Array: ["Dystopian"] + - "new_tags" + - Where: + - Expr.array_contains_any: + - Field: new_tags + - - Constant: non_existent_tag + - Field: genre + - Select: + - title + - genre + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + genre: "Dystopian" + - title: "The Handmaid's Tale" + genre: "Dystopian" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + new_tags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "Dystopian" + name: array + name: array_concat + name: add_fields + - args: + - functionValue: + args: + - fieldReferenceValue: new_tags + - functionValue: + args: + - stringValue: "non_existent_tag" + - fieldReferenceValue: genre + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + genre: + fieldReferenceValue: genre + name: select + - args: + - mapValue: + fields: + 
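Note: the array cases in this file now route plain Python lists through the Array(...) function expression instead of a literal arrayValue, matching the proto assertions above. A short sketch of the coercion, mirroring test_array_contains_any in the unit tests (illustrative, not a suite fragment):

    from google.cloud.firestore_v1 import pipeline_expressions as expr
    from google.cloud.firestore_v1.pipeline_expressions import Constant, Field

    # A plain list argument is wrapped in Array([...]) on the wire:
    check = Field.of("new_tags").array_contains_any(
        [Constant.of("non_existent_tag"), Field.of("genre")]
    )
    # -> array_contains_any(new_tags, array('non_existent_tag', genre))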
direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayConcatLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.array_concat: + - Array: [1, 2, 3] + - Array: [4, 5] + - "concatenated" + assert_results: + - concatenated: [1, 2, 3, 4, 5] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + concatenated: + functionValue: + args: + - functionValue: + args: + - integerValue: '1' + - integerValue: '2' + - integerValue: '3' + name: array + - functionValue: + args: + - integerValue: '4' + - integerValue: '5' + name: array + name: array_concat + name: select + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - Expr.exists: + - Field: awards.pulitzer + - Expr.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testSum + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpr: + - Expr.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + - description: testStringContains + pipeline: + - Collection: books + - Where: + - Expr.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpr: + - Expr.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 3 - embedding_length: 4 - description: testTimestampFunctions pipeline: @@ -1956,3 +2708,596 @@ tests: conditional_field: "Dystopian" - title: "Dune" conditional_field: "Frank Herbert" + - description: testFindNearestEuclidean + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: EUCLIDEAN + options: + FindNearestOptions: + limit: 2 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 0.0 + - distance: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: euclidean + options: + limit: + integerValue: '2' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testMathExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.abs: + - Field: rating + - "abs_rating" + - AliasedExpr: + - Expr.ceil: + - Field: rating + - "ceil_rating" + - AliasedExpr: + - Expr.exp: + - Field: rating + - "exp_rating" + - AliasedExpr: + - Expr.floor: + - Field: rating + - "floor_rating" + - AliasedExpr: + - Expr.ln: + - Field: rating + - "ln_rating" + - AliasedExpr: + - Expr.log10: + - Field: rating + - "log_rating_base10" + - AliasedExpr: + - Expr.log: + - Field: rating + - Constant: 2 + - 
"log_rating_base2" + - AliasedExpr: + - Expr.pow: + - Field: rating + - Constant: 2 + - "pow_rating" + - AliasedExpr: + - Expr.sqrt: + - Field: rating + - "sqrt_rating" + assert_results_approximate: + - abs_rating: 4.2 + ceil_rating: 5.0 + exp_rating: 66.686331 + floor_rating: 4.0 + ln_rating: 1.4350845 + log_rating_base10: 0.623249 + log_rating_base2: 2.0704 + pow_rating: 17.64 + sqrt_rating: 2.049390 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + abs_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: abs + ceil_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ceil + exp_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: exp + floor_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: floor + ln_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ln + log_rating_base10: + functionValue: + args: + - fieldReferenceValue: rating + name: log10 + log_rating_base2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: log + pow_rating: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: pow + sqrt_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: sqrt + name: select + - description: testRoundExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "To Kill a Mockingbird" # rating 4.2 + - Constant: "Pride and Prejudice" # rating 4.5 + - Constant: "The Lord of the Rings" # rating 4.7 + - Select: + - title + - AliasedExpr: + - Expr.round: + - Field: rating + - "round_rating" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + round_rating: 5.0 + - title: "The Lord of the Rings" + round_rating: 5.0 + - title: "To Kill a Mockingbird" + round_rating: 4.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - functionValue: + args: + - stringValue: "To Kill a Mockingbird" + - stringValue: "Pride and Prejudice" + - stringValue: "The Lord of the Rings" + name: array + name: equal_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + round_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: round + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testFindNearestDotProduct + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: DOT_PRODUCT + options: + FindNearestOptions: + limit: 3 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 38.0 + - distance: 17.0 + - distance: 14.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: dot_product + options: + limit: + integerValue: '3' + distance_field: + fieldReferenceValue: distance + - 
name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testDotProductWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.dot_product: + - Field: embedding + - Vector: [1.0, 1.0, 1.0] + - "dot_product_result" + assert_results: + - dot_product_result: 6.0 + - description: testEuclideanDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.euclidean_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "euclidean_distance_result" + assert_results: + - euclidean_distance_result: 0.0 + - description: testCosineDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.cosine_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "cosine_distance_result" + assert_results: + - cosine_distance_result: 0.0 + - description: testStringFunctions - ToLower + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_lower: + - Field: title + - "lower_title" + assert_results: + - lower_title: "the hitchhiker's guide to the galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + lower_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - description: testStringFunctions - ToUpper + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_upper: + - Field: title + - "upper_title" + assert_results: + - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + upper_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_upper + name: select + - description: testStringFunctions - Trim + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.trim: + - Expr.string_concat: + - Constant: " " + - Field: title + - Constant: " " + - "trimmed_title" + assert_results: + - trimmed_title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + trimmed_title: + functionValue: + args: + - functionValue: + args: + - stringValue: " " + - fieldReferenceValue: title + - stringValue: " " + name: string_concat + name: trim + name: select + - description: testStringFunctions - StringReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Jane Austen" + - Select: + - AliasedExpr: + - Expr.string_reverse: + - Field: title + - "reversed_title" + assert_results: + - 
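Note: the Trim case above nests string_concat inside trim. The same tree in Python, assuming the infix forms behave like the similar helpers covered in the unit tests (a sketch, not a test fragment):

    from google.cloud.firestore_v1.pipeline_expressions import Constant, Field

    padded = Constant.of("  ").string_concat(Field.of("title"), Constant.of("  "))
    trimmed = padded.trim()  # trim(string_concat('  ', title, '  '))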
reversed_title: "ecidujerP dna edirP" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Jane Austen" + name: equal + name: where + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: string_reverse + name: select + - description: testStringFunctions - Substring + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 4 + - Constant: 11 + - "substring_title" + assert_results: + - substring_title: "Hitchhiker'" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '4' + - integerValue: '11' + name: substring + name: select + - description: testStringFunctions - Substring without length + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Fyodor Dostoevsky" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 10 + - "substring_title" + assert_results: + - substring_title: "Punishment" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Fyodor Dostoevsky" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '10' + name: substring + name: select + - description: testStringFunctions - Join + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.join: + - Field: tags + - Constant: ", " + - "joined_tags" + assert_results: + - joined_tags: "comedy, space, adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + joined_tags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: ", " + name: join + name: select diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index d4c654e63..682fe5e23 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -28,6 +28,7 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions as expr from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client, AsyncClient @@ -86,7 +87,13 @@ def test_pipeline_expected_errors(test_dict, client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) def 
test_pipeline_results(test_dict, client): @@ -94,12 +101,23 @@ def test_pipeline_results(test_dict, client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count @@ -126,7 +144,13 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) @pytest.mark.asyncio @@ -135,12 +159,23 @@ async def test_pipeline_results_async(test_dict, async_client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() async for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count @@ -218,7 +253,11 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ if isinstance(yaml_args, dict): return callable_obj(**_parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list): + elif isinstance(yaml_args, list) and not ( + callable_obj == expr.Constant + or callable_obj == Vector + or callable_obj == expr.Array + ): # yaml has an array of arguments. 
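Note: the assert_results_approximate branches above lean on the fact that pytest.approx accepts a mapping of numeric expectations and compares each value within the given tolerance. A standalone illustration with hypothetical values:

    import pytest

    got = {"pow_rating": 17.64, "sqrt_rating": 2.0493901}
    assert got == pytest.approx(
        {"pow_rating": 17.64, "sqrt_rating": 2.049390}, abs=1e-4
    )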
Treat as args return callable_obj(*_parse_expressions(client, yaml_args)) else: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index a8f94e2ba..c2bd93ef8 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -109,9 +109,11 @@ def _clean_results(results): if isinstance(query, BaseAggregationQuery): # aggregation queries return a list of lists of aggregation results query_results = _clean_results( - list(itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in query.get()] - )) + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) ) else: # other qureies return a simple list of results @@ -1531,6 +1533,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): results.get_explain_metrics() verify_pipeline(query) + @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index b78a77786..d053cbd7a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -208,7 +208,9 @@ def _clean_results(results): await pipeline.execute() else: # ensure results match query - pipeline_results = _clean_results([s.data() async for s in pipeline.stream()]) + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] + ) assert query_results == pipeline_results except FailedPrecondition as e: # if testing against a non-enterprise db, skip this check @@ -216,7 +218,6 @@ def _clean_results(results): raise e - @pytest.fixture(scope="module") def event_loop(): """Change event_loop fixture to module level.""" diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 5064e87ae..9a20fd386 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1136,7 +1136,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index fdd4a1450..701feab5b 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -810,7 +810,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_async_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 9f06c47b8..aec721e7d 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -24,7 +24,6 @@ from google.cloud.firestore_v1._helpers import GeoPoint import google.cloud.firestore_v1.pipeline_expressions as expr from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr -from google.cloud.firestore_v1.pipeline_expressions import _ListOfExprs from google.cloud.firestore_v1.pipeline_expressions import Expr from google.cloud.firestore_v1.pipeline_expressions import Constant from 
google.cloud.firestore_v1.pipeline_expressions import Field @@ -97,13 +96,6 @@ class TestConstant: Value(timestamp_value={"seconds": 1747008000}), ), (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), - ( - [0.0, 1.0, 2.0], - Value( - array_value={"values": [Value(double_value=i) for i in range(3)]} - ), - ), - ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), ( Vector([1.0, 2.0]), Value( @@ -173,57 +165,6 @@ def test_equality(self, first, second, expected): assert (first == second) is expected -class TestListOfExprs: - def test_to_pb(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - result = instance._to_pb() - assert len(result.array_value.values) == 2 - assert result.array_value.values[0].integer_value == 1 - assert result.array_value.values[1].integer_value == 2 - - def test_empty_to_pb(self): - instance = _ListOfExprs([]) - result = instance._to_pb() - assert len(result.array_value.values) == 0 - - def test_repr(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - repr_string = repr(instance) - assert repr_string == "[Constant.of(1), Constant.of(2)]" - empty_instance = _ListOfExprs([]) - empty_repr_string = repr(empty_instance) - assert empty_repr_string == "[]" - - @pytest.mark.parametrize( - "first,second,expected", - [ - (_ListOfExprs([]), _ListOfExprs([]), True), - (_ListOfExprs([]), _ListOfExprs([Constant(1)]), False), - (_ListOfExprs([Constant(1)]), _ListOfExprs([]), False), - ( - _ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(1)]), - True, - ), - ( - _ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(2)]), - False, - ), - ( - _ListOfExprs([Constant(1), Constant(2)]), - _ListOfExprs([Constant(1), Constant(2)]), - True, - ), - (_ListOfExprs([Constant(1)]), [Constant(1)], False), - (_ListOfExprs([Constant(1)]), [1], False), - (_ListOfExprs([Constant(1)]), object(), False), - ], - ) - def test_equality(self, first, second, expected): - assert (first == second) is expected - - class TestSelectable: """ contains tests for each Expr class that derives from Selectable @@ -370,7 +311,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): field1 = Field.of("field1") field2 = Field.of("field2") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) - expected_cond2 = expr.And(field2.exists(), field2.equal(Constant(None))) + expected_cond2 = expr.And(field2.exists(), field2.is_null()) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -457,9 +398,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): field3 = Field.of("field3") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) - expected_cond3 = expr.And( - field3.exists(), expr.Not(field3.equal(Constant(None))) - ) + expected_cond3 = expr.And(field3.exists(), field3.is_not_null()) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -491,15 +430,15 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, - lambda f: expr.Not(f.is_nan()), + Expr.is_not_nan, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.equal(None), + Expr.is_null, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - 
lambda f: expr.Not(f.equal(None)), + Expr.is_not_null, ), ], ) @@ -643,6 +582,69 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) +class TestArray: + """Tests for the array class""" + + def test_array(self): + arg1 = Field.of("field1") + instance = expr.Array([arg1]) + assert instance.name == "array" + assert instance.params == [arg1] + assert repr(instance) == "Array([Field.of('field1')])" + + def test_empty_array(self): + instance = expr.Array([]) + assert instance.name == "array" + assert instance.params == [] + assert repr(instance) == "Array([])" + + def test_array_w_primitives(self): + a = expr.Array([1, Constant.of(2), "3"]) + assert a.name == "array" + assert a.params == [Constant.of(1), Constant.of(2), Constant.of("3")] + assert repr(a) == "Array([Constant.of(1), Constant.of(2), Constant.of('3')])" + + def test_array_w_non_list(self): + with pytest.raises(TypeError): + expr.Array(1) + + +class TestMap: + """Tests for the map class""" + + def test_map(self): + instance = expr.Map({Constant.of("a"): Constant.of("b")}) + assert instance.name == "map" + assert instance.params == [Constant.of("a"), Constant.of("b")] + assert repr(instance) == "Map({'a': 'b'})" + + def test_map_w_primitives(self): + instance = expr.Map({"a": "b", "0": 0, "bool": True}) + assert instance.params == [ + Constant.of("a"), + Constant.of("b"), + Constant.of("0"), + Constant.of(0), + Constant.of("bool"), + Constant.of(True), + ] + assert repr(instance) == "Map({'a': 'b', '0': 0, 'bool': True})" + + def test_empty_map(self): + instance = expr.Map({}) + assert instance.name == "map" + assert instance.params == [] + assert repr(instance) == "Map({})" + + def test_w_exprs(self): + instance = expr.Map({Constant.of("a"): expr.Array([1, 2, 3])}) + assert instance.params == [Constant.of("a"), expr.Array([1, 2, 3])] + assert ( + repr(instance) + == "Map({'a': Array([Constant.of(1), Constant.of(2), Constant.of(3)])})" + ) + + class TestExpressionMethods: """ contains test methods for each Expr method @@ -723,10 +725,13 @@ def test_array_contains_any(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_any([Element1, Element2])" + assert instance.params[1].params == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_any(Array([Element1, Element2]))" + ) infix_instance = arg1.array_contains_any([arg2, arg3]) assert infix_instance == instance @@ -805,10 +810,10 @@ def test_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.equal_any(arg1, [arg2, arg3]) assert instance.name == "equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.equal_any(Array([Value1, Value2]))" infix_instance = arg1.equal_any([arg2, arg3]) assert infix_instance == instance @@ -818,13 +823,32 @@ def test_not_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.not_equal_any(arg1, [arg2, arg3]) assert 
instance.name == "not_equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.not_equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.not_equal_any(Array([Value1, Value2]))" infix_instance = arg1.not_equal_any([arg2, arg3]) assert infix_instance == instance + def test_is_absent(self): + arg1 = self._make_arg("Field") + instance = Expr.is_absent(arg1) + assert instance.name == "is_absent" + assert instance.params == [arg1] + assert repr(instance) == "Field.is_absent()" + infix_instance = arg1.is_absent() + assert infix_instance == instance + + def test_if_absent(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("ThenExpr") + instance = Expr.if_absent(arg1, arg2) + assert instance.name == "if_absent" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Field.if_absent(ThenExpr)" + infix_instance = arg1.if_absent(arg2) + assert infix_instance == instance + def test_is_nan(self): arg1 = self._make_arg("Value") instance = Expr.is_nan(arg1) @@ -834,15 +858,52 @@ def test_is_nan(self): infix_instance = arg1.is_nan() assert infix_instance == instance + def test_is_not_nan(self): + arg1 = self._make_arg("Value") + instance = Expr.is_not_nan(arg1) + assert instance.name == "is_not_nan" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_nan()" + infix_instance = arg1.is_not_nan() + assert infix_instance == instance + def test_is_null(self): arg1 = self._make_arg("Value") - instance = Expr.is_ull(arg1) + instance = Expr.is_null(arg1) assert instance.name == "is_null" assert instance.params == [arg1] assert repr(instance) == "Value.is_null()" infix_instance = arg1.is_null() assert infix_instance == instance + def test_is_not_null(self): + arg1 = self._make_arg("Value") + instance = Expr.is_not_null(arg1) + assert instance.name == "is_not_null" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_null()" + infix_instance = arg1.is_not_null() + assert infix_instance == instance + + def test_is_error(self): + arg1 = self._make_arg("Value") + instance = Expr.is_error(arg1) + assert instance.name == "is_error" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_error()" + infix_instance = arg1.is_error() + assert infix_instance == instance + + def test_if_error(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("ThenExpr") + instance = Expr.if_error(arg1, arg2) + assert instance.name == "if_error" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.if_error(ThenExpr)" + infix_instance = arg1.if_error(arg2) + assert infix_instance == instance + def test_not(self): arg1 = self._make_arg("Condition") instance = expr.Not(arg1) @@ -856,10 +917,13 @@ def test_array_contains_all(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_all([Element1, Element2])" + assert instance.params[1].params == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(Array([Element1, Element2]))" + ) infix_instance = 
arg1.array_contains_all([arg2, arg3]) assert infix_instance == instance @@ -970,6 +1034,73 @@ def test_logical_minimum(self): infix_instance = arg1.logical_minimum(arg2) assert infix_instance == instance + def test_to_lower(self): + arg1 = self._make_arg("Input") + instance = Expr.to_lower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_lower()" + infix_instance = arg1.to_lower() + assert infix_instance == instance + + def test_to_upper(self): + arg1 = self._make_arg("Input") + instance = Expr.to_upper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_upper()" + infix_instance = arg1.to_upper() + assert infix_instance == instance + + def test_trim(self): + arg1 = self._make_arg("Input") + instance = Expr.trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Input.trim()" + infix_instance = arg1.trim() + assert infix_instance == instance + + def test_string_reverse(self): + arg1 = self._make_arg("Input") + instance = Expr.string_reverse(arg1) + assert instance.name == "string_reverse" + assert instance.params == [arg1] + assert repr(instance) == "Input.string_reverse()" + infix_instance = arg1.string_reverse() + assert infix_instance == instance + + def test_substring(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + instance = Expr.substring(arg1, arg2) + assert instance.name == "substring" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Input.substring(Position)" + infix_instance = arg1.substring(arg2) + assert infix_instance == instance + + def test_substring_w_length(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + arg3 = self._make_arg("Length") + instance = Expr.substring(arg1, arg2, arg3) + assert instance.name == "substring" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Input.substring(Position, Length)" + infix_instance = arg1.substring(arg2, arg3) + assert infix_instance == instance + + def test_join(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Separator") + instance = Expr.join(arg1, arg2) + assert instance.name == "join" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Array.join(Separator)" + infix_instance = arg1.join(arg2) + assert infix_instance == instance + def test_map_get(self): arg1 = self._make_arg("Map") arg2 = "key" @@ -980,6 +1111,27 @@ def test_map_get(self): infix_instance = arg1.map_get(Constant.of(arg2)) assert infix_instance == instance + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = "key" + instance = Expr.map_remove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_remove(Constant.of('key'))" + infix_instance = arg1.map_remove(Constant.of(arg2)) + assert infix_instance == instance + + def test_map_merge(self): + arg1 = expr.Map({"a": 1}) + arg2 = expr.Map({"b": 2}) + arg3 = {"c": 3} + instance = Expr.map_merge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, expr.Map(arg3)] + assert repr(instance) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))" + infix_instance = arg1.map_merge(arg2, arg3) + assert infix_instance == instance + def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") @@ -1021,6 +1173,12 @@ def test_subtract(self): 
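Note: test_map_merge above pins down the dict coercion: trailing plain dicts passed to map_merge are converted to Map(...) expressions. The same call, sketched standalone:

    from google.cloud.firestore_v1 import pipeline_expressions as expr

    base = expr.Map({"a": 1})
    merged = base.map_merge(expr.Map({"b": 2}), {"c": 3})  # dict -> Map({'c': 3})
    # repr(merged) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))"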
infix_instance = arg1.subtract(arg2) assert infix_instance == instance + def test_current_timestamp(self): + instance = expr.CurrentTimestamp() + assert instance.name == "current_timestamp" + assert instance.params == [] + assert repr(instance) == "CurrentTimestamp()" + def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") @@ -1097,6 +1255,54 @@ def test_unix_seconds_to_timestamp(self): infix_instance = arg1.unix_seconds_to_timestamp() assert infix_instance == instance + def test_euclidean_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.euclidean_distance(arg1, arg2) + assert instance.name == "euclidean_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.euclidean_distance(Vector2)" + infix_instance = arg1.euclidean_distance(arg2) + assert infix_instance == instance + + def test_cosine_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.cosine_distance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.cosine_distance(Vector2)" + infix_instance = arg1.cosine_distance(arg2) + assert infix_instance == instance + + def test_dot_product(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.dot_product(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.dot_product(Vector2)" + infix_instance = arg1.dot_product(arg2) + assert infix_instance == instance + + @pytest.mark.parametrize( + "method", ["euclidean_distance", "cosine_distance", "dot_product"] + ) + @pytest.mark.parametrize( + "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), []] + ) + def test_vector_ctor(self, method, input): + """ + test constructing various vector expressions with + different inputs + """ + arg1 = self._make_arg("VectorRef") + instance = getattr(arg1, method)(input) + assert instance.name == method + got_second_param = instance.params[1] + assert isinstance(got_second_param, Constant) + assert isinstance(got_second_param.value, Vector) + def test_vector_length(self): arg1 = self._make_arg("Array") instance = Expr.vector_length(arg1) @@ -1116,6 +1322,98 @@ def test_add(self): infix_instance = arg1.add(arg2) assert infix_instance == instance + def test_abs(self): + arg1 = self._make_arg("Value") + instance = Expr.abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Value.abs()" + infix_instance = arg1.abs() + assert infix_instance == instance + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = Expr.ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Value.ceil()" + infix_instance = arg1.ceil() + assert infix_instance == instance + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = Expr.exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Value.exp()" + infix_instance = arg1.exp() + assert infix_instance == instance + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = Expr.floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Value.floor()" + infix_instance = arg1.floor() + assert infix_instance == instance + + def test_ln(self): + arg1 = 
self._make_arg("Value") + instance = Expr.ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Value.ln()" + infix_instance = arg1.ln() + assert infix_instance == instance + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = Expr.log(arg1, arg2) + assert instance.name == "log" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.log(Base)" + infix_instance = arg1.log(arg2) + assert infix_instance == instance + + def test_log10(self): + arg1 = self._make_arg("Value") + instance = Expr.log10(arg1) + assert instance.name == "log10" + assert instance.params == [arg1] + assert repr(instance) == "Value.log10()" + infix_instance = arg1.log10() + assert infix_instance == instance + + def test_pow(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Exponent") + instance = Expr.pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.pow(Exponent)" + infix_instance = arg1.pow(arg2) + assert infix_instance == instance + + def test_round(self): + arg1 = self._make_arg("Value") + instance = Expr.round(arg1) + assert instance.name == "round" + assert instance.params == [arg1] + assert repr(instance) == "Value.round()" + infix_instance = arg1.round() + assert infix_instance == instance + + def test_sqrt(self): + arg1 = self._make_arg("Value") + instance = Expr.sqrt(arg1) + assert instance.name == "sqrt" + assert instance.params == [arg1] + assert repr(instance) == "Value.sqrt()" + infix_instance = arg1.sqrt() + assert infix_instance == instance + def test_array_length(self): arg1 = self._make_arg("Array") instance = Expr.array_length(arg1) @@ -1134,6 +1432,29 @@ def test_array_reverse(self): infix_instance = arg1.array_reverse() assert infix_instance == instance + def test_array_concat(self): + arg1 = self._make_arg("ArrayRef1") + arg2 = self._make_arg("ArrayRef2") + instance = Expr.array_concat(arg1, arg2) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayRef1.array_concat(ArrayRef2)" + infix_instance = arg1.array_concat(arg2) + assert infix_instance == instance + + def test_array_concat_multiple(self): + arg1 = expr.Array([Constant.of(0)]) + arg2 = Field.of("ArrayRef2") + arg3 = Field.of("ArrayRef3") + arg4 = [self._make_arg("Constant")] + instance = arg1.array_concat(arg2, arg3, arg4) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3, expr.Array(arg4)] + assert ( + repr(instance) + == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), Field.of('ArrayRef3'), Array([Constant]))" + ) + def test_byte_length(self): arg1 = self._make_arg("Expr") instance = Expr.byte_length(arg1) @@ -1152,6 +1473,26 @@ def test_char_length(self): infix_instance = arg1.char_length() assert infix_instance == instance + def test_concat(self): + arg1 = self._make_arg("First") + arg2 = self._make_arg("Second") + arg3 = "Third" + instance = Expr.concat(arg1, arg2, arg3) + assert instance.name == "concat" + assert instance.params == [arg1, arg2, Constant.of(arg3)] + assert repr(instance) == "First.concat(Second, Constant.of('Third'))" + infix_instance = arg1.concat(arg2, arg3) + assert infix_instance == instance + + def test_length(self): + arg1 = self._make_arg("Expr") + instance = Expr.length(arg1) + assert instance.name == "length" + assert instance.params == [arg1] + assert repr(instance) == "Expr.length()" + 
infix_instance = arg1.length() + assert infix_instance == instance + def test_collection_id(self): arg1 = self._make_arg("Value") instance = Expr.collection_id(arg1) @@ -1161,6 +1502,15 @@ def test_collection_id(self): infix_instance = arg1.collection_id() assert infix_instance == instance + def test_document_id(self): + arg1 = self._make_arg("Value") + instance = Expr.document_id(arg1) + assert instance.name == "document_id" + assert instance.params == [arg1] + assert repr(instance) == "Value.document_id()" + infix_instance = arg1.document_id() + assert infix_instance == instance + def test_sum(self): arg1 = self._make_arg("Value") instance = Expr.sum(arg1) @@ -1194,6 +1544,24 @@ def test_base_count(self): assert instance.params == [] assert repr(instance) == "Count()" + def test_count_if(self): + arg1 = self._make_arg("Value") + instance = Expr.count_if(arg1) + assert instance.name == "count_if" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_if()" + infix_instance = arg1.count_if() + assert infix_instance == instance + + def test_count_distinct(self): + arg1 = self._make_arg("Value") + instance = Expr.count_distinct(arg1) + assert instance.name == "count_distinct" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_distinct()" + infix_instance = arg1.count_distinct() + assert infix_instance == instance + def test_minimum(self): arg1 = self._make_arg("Value") instance = Expr.minimum(arg1) From c62f3d9d0577da84cc217bdf1359a12890975d1a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 30 Oct 2025 12:28:03 -0700 Subject: [PATCH 07/27] chore: improve pipelines tests (#1116) --- google/cloud/firestore_v1/_pipeline_stages.py | 28 +- .../firestore_v1/pipeline_expressions.py | 8 +- pytest.ini | 3 +- tests/system/pipeline_e2e.yaml | 3303 ----------------- tests/system/pipeline_e2e/aggregates.yaml | 269 ++ tests/system/pipeline_e2e/array.yaml | 388 ++ tests/system/pipeline_e2e/data.yaml | 142 + tests/system/pipeline_e2e/date_and_time.yaml | 103 + tests/system/pipeline_e2e/general.yaml | 784 ++++ tests/system/pipeline_e2e/logical.yaml | 673 ++++ tests/system/pipeline_e2e/map.yaml | 269 ++ tests/system/pipeline_e2e/math.yaml | 309 ++ tests/system/pipeline_e2e/string.yaml | 654 ++++ tests/system/pipeline_e2e/vector.yaml | 160 + tests/system/test__helpers.py | 3 +- tests/system/test_pipeline_acceptance.py | 50 +- tests/system/test_system.py | 18 +- tests/system/test_system_async.py | 12 + tests/unit/v1/test_aggregation.py | 33 +- tests/unit/v1/test_async_aggregation.py | 15 +- tests/unit/v1/test_async_pipeline.py | 4 +- tests/unit/v1/test_base_query.py | 12 +- tests/unit/v1/test_pipeline.py | 4 +- tests/unit/v1/test_pipeline_expressions.py | 24 + tests/unit/v1/test_pipeline_stages.py | 26 +- 25 files changed, 3910 insertions(+), 3384 deletions(-) delete mode 100644 tests/system/pipeline_e2e.yaml create mode 100644 tests/system/pipeline_e2e/aggregates.yaml create mode 100644 tests/system/pipeline_e2e/array.yaml create mode 100644 tests/system/pipeline_e2e/data.yaml create mode 100644 tests/system/pipeline_e2e/date_and_time.yaml create mode 100644 tests/system/pipeline_e2e/general.yaml create mode 100644 tests/system/pipeline_e2e/logical.yaml create mode 100644 tests/system/pipeline_e2e/map.yaml create mode 100644 tests/system/pipeline_e2e/math.yaml create mode 100644 tests/system/pipeline_e2e/string.yaml create mode 100644 tests/system/pipeline_e2e/vector.yaml diff --git a/google/cloud/firestore_v1/_pipeline_stages.py 
b/google/cloud/firestore_v1/_pipeline_stages.py index c63b748ac..62503404e 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -112,11 +112,13 @@ class UnnestOptions: storing the original 0-based index of the element within the array. """ - def __init__(self, index_field: str): - self.index_field = index_field + def __init__(self, index_field: Field | str): + self.index_field = ( + index_field if isinstance(index_field, Field) else Field.of(index_field) + ) def __repr__(self): - return f"{self.__class__.__name__}(index_field={self.index_field!r})" + return f"{self.__class__.__name__}(index_field={self.index_field.path!r})" class Stage(ABC): @@ -258,13 +260,7 @@ def of(*documents: "BaseDocumentReference") -> "Documents": return Documents(*doc_paths) def _pb_args(self): - return [ - Value( - array_value={ - "values": [Value(string_value=path) for path in self.paths] - } - ) - ] + return [Value(reference_value=path) for path in self.paths] class FindNearest(Stage): @@ -306,15 +302,23 @@ def _pb_options(self) -> dict[str, Value]: class GenericStage(Stage): """Represents a generic, named stage with parameters.""" - def __init__(self, name: str, *params: Expr | Value): + def __init__( + self, name: str, *params: Expr | Value, options: dict[str, Expr | Value] = {} + ): super().__init__(name) self.params: list[Value] = [ p._to_pb() if isinstance(p, Expr) else p for p in params ] + self.options: dict[str, Value] = { + k: v._to_pb() if isinstance(v, Expr) else v for k, v in options.items() + } def _pb_args(self): return self.params + def _pb_options(self): + return self.options + def __repr__(self): return f"{self.__class__.__name__}(name='{self.name}')" @@ -437,7 +441,7 @@ def _pb_args(self): def _pb_options(self): options = {} if self.options is not None: - options["index_field"] = Value(string_value=self.options.index_field) + options["index_field"] = self.options.index_field._to_pb() return options diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index b113e2874..30f3de995 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -112,8 +112,6 @@ def _to_pb(self) -> Value: @staticmethod def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False) -> "Expr": """Convert arbitrary object to an Expr.""" - if isinstance(o, Constant) and isinstance(o.value, list): - o = o.value if isinstance(o, Expr): return o if isinstance(o, dict): @@ -143,7 +141,7 @@ def __init__(self, instance_func): def static_func(self, first_arg, *other_args, **kwargs): if not isinstance(first_arg, (Expr, str)): raise TypeError( - f"`expressions must be called on an Expr or a string representing a field name. got {type(first_arg)}." + f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}." ) first_expr = ( Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg @@ -152,7 +150,7 @@ def static_func(self, first_arg, *other_args, **kwargs): def __get__(self, instance, owner): if instance is None: - return self.static_func.__get__(instance, owner) + return self.static_func else: return self.instance_func.__get__(instance, owner) @@ -1280,7 +1278,7 @@ def map_get(self, key: str | Constant[str]) -> "Expr": @expose_as_static def map_remove(self, key: str | Constant[str]) -> "Expr": - """Remove a key from the map produced by evaluating this expression. 
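Note: the UnnestOptions change above widens index_field to accept either a Field or a bare path string, normalizing both to a Field internally. The behavior below follows directly from the constructor and __repr__ in that hunk, assuming Field.of keeps the raw path it is given:

    from google.cloud.firestore_v1 import _pipeline_stages as stages
    from google.cloud.firestore_v1.pipeline_expressions import Field

    assert (
        repr(stages.UnnestOptions("tag_index"))
        == "UnnestOptions(index_field='tag_index')"
    )
    assert stages.UnnestOptions(Field.of("tag_index")).index_field.path == "tag_index"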
+ """Remove a key from a the map produced by evaluating this expression. Example: >>> Map({"city": "London"}).map_remove("city") diff --git a/pytest.ini b/pytest.ini index eac8ea123..308d1b494 100644 --- a/pytest.ini +++ b/pytest.ini @@ -22,4 +22,5 @@ filterwarnings = ignore:.*The \`credentials_file\` argument is deprecated.*:DeprecationWarning # Remove after updating test dependencies that use asyncio.iscoroutinefunction ignore:.*\'asyncio.iscoroutinefunction\' is deprecated.*:DeprecationWarning - ignore:.*\'asyncio.get_event_loop_policy\' is deprecated.*:DeprecationWarning \ No newline at end of file + ignore:.*\'asyncio.get_event_loop_policy\' is deprecated.*:DeprecationWarning + ignore:.*Please upgrade to the latest Python version.*:FutureWarning diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml deleted file mode 100644 index 38595224a..000000000 --- a/tests/system/pipeline_e2e.yaml +++ /dev/null @@ -1,3303 +0,0 @@ -data: - books: - book1: - title: "The Hitchhiker's Guide to the Galaxy" - author: "Douglas Adams" - genre: "Science Fiction" - published: 1979 - rating: 4.2 - tags: - - comedy - - space - - adventure - awards: - hugo: true - nebula: false - book2: - title: "Pride and Prejudice" - author: "Jane Austen" - genre: "Romance" - published: 1813 - rating: 4.5 - tags: - - classic - - social commentary - - love - awards: - none: true - book3: - title: "One Hundred Years of Solitude" - author: "Gabriel García Márquez" - genre: "Magical Realism" - published: 1967 - rating: 4.3 - tags: - - family - - history - - fantasy - awards: - nobel: true - nebula: false - book4: - title: "The Lord of the Rings" - author: "J.R.R. Tolkien" - genre: "Fantasy" - published: 1954 - rating: 4.7 - tags: - - adventure - - magic - - epic - awards: - hugo: false - nebula: false - book5: - title: "The Handmaid's Tale" - author: "Margaret Atwood" - genre: "Dystopian" - published: 1985 - rating: 4.1 - tags: - - feminism - - totalitarianism - - resistance - awards: - arthur c. clarke: true - booker prize: false - book6: - title: "Crime and Punishment" - author: "Fyodor Dostoevsky" - genre: "Psychological Thriller" - published: 1866 - rating: 4.3 - tags: - - philosophy - - crime - - redemption - awards: - none: true - book7: - title: "To Kill a Mockingbird" - author: "Harper Lee" - genre: "Southern Gothic" - published: 1960 - rating: 4.2 - tags: - - racism - - injustice - - coming-of-age - awards: - pulitzer: true - book8: - title: "1984" - author: "George Orwell" - genre: "Dystopian" - published: 1949 - rating: 4.2 - tags: - - surveillance - - totalitarianism - - propaganda - awards: - prometheus: true - book9: - title: "The Great Gatsby" - author: "F. 
Scott Fitzgerald" - genre: "Modernist" - published: 1925 - rating: 4.0 - tags: - - wealth - - american dream - - love - awards: - none: true - book10: - title: "Dune" - author: "Frank Herbert" - genre: "Science Fiction" - published: 1965 - rating: 4.6 - tags: - - politics - - desert - - ecology - awards: - hugo: true - nebula: true - timestamps: - ts1: - time: "1993-04-28T12:01:00.654321+00:00" - micros: 735998460654321 - millis: 735998460654 - seconds: 735998460 - vectors: - vec1: - embedding: [1.0, 2.0, 3.0] - vec2: - embedding: [4.0, 5.0, 6.0, 7.0] - vec3: - embedding: [5.0, 6.0, 7.0] - vec4: - embedding: [1.0, 2.0, 4.0] -tests: - - description: "testAggregates - count" - pipeline: - - Collection: books - - Aggregate: - - AliasedExpr: - - Expr.count: - - Field: rating - - "count" - assert_results: - - count: 10 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - count: - functionValue: - name: count - args: - - fieldReferenceValue: rating - - mapValue: {} - name: aggregate - - description: "testAggregates - count_if" - pipeline: - - Collection: books - - Aggregate: - - AliasedExpr: - - Expr.count_if: - - Expr.greater_than: - - Field: rating - - Constant: 4.2 - - "count_if_rating_gt_4_2" - assert_results: - - count_if_rating_gt_4_2: 5 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - count_if_rating_gt_4_2: - functionValue: - name: count_if - args: - - functionValue: - name: greater_than - args: - - fieldReferenceValue: rating - - doubleValue: 4.2 - - mapValue: {} - name: aggregate - - description: "testAggregates - count_distinct" - pipeline: - - Collection: books - - Aggregate: - - AliasedExpr: - - Expr.count_distinct: - - Field: genre - - "distinct_genres" - assert_results: - - distinct_genres: 8 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - distinct_genres: - functionValue: - name: count_distinct - args: - - fieldReferenceValue: genre - - mapValue: {} - name: aggregate - - description: "testAggregates - avg, count, max" - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Aggregate: - - AliasedExpr: - - Expr.count: - - Field: rating - - "count" - - AliasedExpr: - - Expr.average: - - Field: rating - - "avg_rating" - - AliasedExpr: - - Expr.maximum: - - Field: rating - - "max_rating" - assert_results: - - count: 2 - avg_rating: 4.4 - max_rating: 4.6 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: equal - name: where - - args: - - mapValue: - fields: - avg_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: average - count: - functionValue: - name: count - args: - - fieldReferenceValue: rating - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: maximum - - mapValue: {} - name: aggregate - - description: testGroupBysWithoutAccumulators - pipeline: - - Collection: books - - Where: - - Expr.less_than: - - Field: published - - Constant: 1900 - - Aggregate: - accumulators: [] - groups: [genre] - assert_error: ".* requires at least one accumulator" - - description: testGroupBysAndAggregate - pipeline: - - Collection: books - - Where: - - Expr.less_than: - - Field: published - - Constant: 1984 - - 
Aggregate: - accumulators: - - AliasedExpr: - - Expr.average: - - Field: rating - - "avg_rating" - groups: [genre] - - Where: - - Expr.greater_than: - - Field: avg_rating - - Constant: 4.3 - - Sort: - - Ordering: - - Field: avg_rating - - ASCENDING - assert_results: - - avg_rating: 4.4 - genre: Science Fiction - - avg_rating: 4.5 - genre: Romance - - avg_rating: 4.7 - genre: Fantasy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1984' - name: less_than - name: where - - args: - - mapValue: - fields: - avg_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: average - - mapValue: - fields: - genre: - fieldReferenceValue: genre - name: aggregate - - args: - - functionValue: - args: - - fieldReferenceValue: avg_rating - - doubleValue: 4.3 - name: greater_than - name: where - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: avg_rating - name: sort - - description: testMinMax - pipeline: - - Collection: books - - Aggregate: - - AliasedExpr: - - Expr.count: - - Field: rating - - "count" - - AliasedExpr: - - Expr.maximum: - - Field: rating - - "max_rating" - - AliasedExpr: - - Expr.minimum: - - Field: published - - "min_published" - assert_results: - - count: 10 - max_rating: 4.7 - min_published: 1813 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - count: - functionValue: - args: - - fieldReferenceValue: rating - name: count - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: maximum - min_published: - functionValue: - args: - - fieldReferenceValue: published - name: minimum - - mapValue: {} - name: aggregate - - description: selectSpecificFields - pipeline: - - Collection: books - - Select: - - title - - author - - Sort: - - Ordering: - - Field: author - - ASCENDING - assert_results: - - title: "The Hitchhiker's Guide to the Galaxy" - author: "Douglas Adams" - - title: "The Great Gatsby" - author: "F. Scott Fitzgerald" - - title: "Dune" - author: "Frank Herbert" - - title: "Crime and Punishment" - author: "Fyodor Dostoevsky" - - title: "One Hundred Years of Solitude" - author: "Gabriel García Márquez" - - title: "1984" - author: "George Orwell" - - title: "To Kill a Mockingbird" - author: "Harper Lee" - - title: "The Lord of the Rings" - author: "J.R.R. 
Tolkien" - - title: "Pride and Prejudice" - author: "Jane Austen" - - title: "The Handmaid's Tale" - author: "Margaret Atwood" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - author: - fieldReferenceValue: author - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author - name: sort - - description: addAndRemoveFields - pipeline: - - Collection: books - - AddFields: - - AliasedExpr: - - Expr.string_concat: - - Field: author - - Constant: _ - - Field: title - - "author_title" - - AliasedExpr: - - Expr.string_concat: - - Field: title - - Constant: _ - - Field: author - - "title_author" - - RemoveFields: - - title_author - - tags - - awards - - rating - - title - - Field: published - - Field: genre - - Field: nestedField # Field does not exist, should be ignored - - Sort: - - Ordering: - - Field: author_title - - ASCENDING - assert_results: - - author: Douglas Adams - author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy - - author: F. Scott Fitzgerald - author_title: F. Scott Fitzgerald_The Great Gatsby - - author: Frank Herbert - author_title: Frank Herbert_Dune - - author: Fyodor Dostoevsky - author_title: Fyodor Dostoevsky_Crime and Punishment - - author: Gabriel García Márquez - author_title: Gabriel García Márquez_One Hundred Years of Solitude - - author: George Orwell - author_title: George Orwell_1984 - - author: Harper Lee - author_title: Harper Lee_To Kill a Mockingbird - - author: J.R.R. Tolkien - author_title: J.R.R. Tolkien_The Lord of the Rings - - author: Jane Austen - author_title: Jane Austen_Pride and Prejudice - - author: Margaret Atwood - author_title: Margaret Atwood_The Handmaid's Tale - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - author_title: - functionValue: - args: - - fieldReferenceValue: author - - stringValue: _ - - fieldReferenceValue: title - name: string_concat - title_author: - functionValue: - args: - - fieldReferenceValue: title - - stringValue: _ - - fieldReferenceValue: author - name: string_concat - name: add_fields - - args: - - fieldReferenceValue: title_author - - fieldReferenceValue: tags - - fieldReferenceValue: awards - - fieldReferenceValue: rating - - fieldReferenceValue: title - - fieldReferenceValue: published - - fieldReferenceValue: genre - - fieldReferenceValue: nestedField - name: remove_fields - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author_title - name: sort - - description: whereByMultipleConditions - pipeline: - - Collection: books - - Where: - - And: - - Expr.greater_than: - - Field: rating - - Constant: 4.5 - - Expr.equal: - - Field: genre - - Constant: Science Fiction - assert_results: - - title: Dune - author: Frank Herbert - genre: Science Fiction - published: 1965 - rating: 4.6 - tags: - - politics - - desert - - ecology - awards: - hugo: true - nebula: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: greater_than - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: equal - name: and - name: where - - description: whereByOrCondition - pipeline: - - Collection: books 
- - Where: - - Or: - - Expr.equal: - - Field: genre - - Constant: Romance - - Expr.equal: - - Field: genre - - Constant: Dystopian - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "1984" - - title: Pride and Prejudice - - title: The Handmaid's Tale - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Romance - name: equal - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Dystopian - name: equal - name: or - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testPipelineWithOffsetAndLimit - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: author - - ASCENDING - - Offset: 5 - - Limit: 3 - - Select: - - title - - author - assert_results: - - title: "1984" - author: George Orwell - - title: To Kill a Mockingbird - author: Harper Lee - - title: The Lord of the Rings - author: J.R.R. Tolkien - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author - name: sort - - args: - - integerValue: '5' - name: offset - - args: - - integerValue: '3' - name: limit - - args: - - mapValue: - fields: - author: - fieldReferenceValue: author - title: - fieldReferenceValue: title - name: select - - description: testArrayContains - pipeline: - - Collection: books - - Where: - - Expr.array_contains: - - Field: tags - - Constant: comedy - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - author: Douglas Adams - awards: - hugo: true - nebula: false - genre: Science Fiction - published: 1979 - rating: 4.2 - tags: ["comedy", "space", "adventure"] - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: tags - - stringValue: comedy - name: array_contains - name: where - - description: testArrayContainsAny - pipeline: - - Collection: books - - Where: - - Expr.array_contains_any: - - Field: tags - - - Constant: comedy - - Constant: classic - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: Pride and Prejudice - - title: The Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: tags - - functionValue: - args: - - stringValue: comedy - - stringValue: classic - name: array - name: array_contains_any - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testArrayContainsAll - pipeline: - - Collection: books - - Where: - - Expr.array_contains_all: - - Field: tags - - - Constant: adventure - - Constant: magic - - Select: - - title - assert_results: - - title: The Lord of the Rings - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: tags - - 
functionValue: - args: - - stringValue: adventure - - stringValue: magic - name: array - name: array_contains_all - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - description: testArrayLength - pipeline: - - Collection: books - - Select: - - AliasedExpr: - - Expr.array_length: - - Field: tags - - "tagsCount" - - Where: - - Expr.equal: - - Field: tagsCount - - Constant: 3 - assert_results: # All documents have 3 tags - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - - tagsCount: 3 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - tagsCount: - functionValue: - args: - - fieldReferenceValue: tags - name: array_length - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: tagsCount - - integerValue: '3' - name: equal - name: where - - description: testStringConcat - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: author - - ASCENDING - - Select: - - AliasedExpr: - - Expr.string_concat: - - Field: author - - Constant: " - " - - Field: title - - "bookInfo" - - Limit: 1 - assert_results: - - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: author - name: sort - - args: - - mapValue: - fields: - bookInfo: - functionValue: - args: - - fieldReferenceValue: author - - stringValue: ' - ' - - fieldReferenceValue: title - name: string_concat - name: select - - args: - - integerValue: '1' - name: limit - - description: testStartsWith - pipeline: - - Collection: books - - Where: - - Expr.starts_with: - - Field: title - - Constant: The - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: The Great Gatsby - - title: The Handmaid's Tale - - title: The Hitchhiker's Guide to the Galaxy - - title: The Lord of the Rings - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: The - name: starts_with - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testEndsWith - pipeline: - - Collection: books - - Where: - - Expr.ends_with: - - Field: title - - Constant: y - - Select: - - title - - Sort: - - Ordering: - - Field: title - - DESCENDING - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - - title: The Great Gatsby - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: y - name: ends_with - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort - - description: testConcat - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Select: - - AliasedExpr: - - Expr.concat: - - Field: 
author - - Constant: ": " - - Field: title - - "author_title" - - AliasedExpr: - - Expr.concat: - - Field: tags - - - Constant: "new_tag" - - "concatenatedTags" - assert_results: - - author_title: "Douglas Adams: The Hitchhiker's Guide to the Galaxy" - concatenatedTags: - - comedy - - space - - adventure - - new_tag - - description: testLength - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Select: - - AliasedExpr: - - Expr.length: - - Field: title - - "titleLength" - - AliasedExpr: - - Expr.length: - - Field: tags - - "tagsLength" - - AliasedExpr: - - Expr.length: - - Field: awards - - "awardsLength" - assert_results: - - titleLength: 36 - tagsLength: 3 - awardsLength: 2 - - description: testCharLength - pipeline: - - Collection: books - - Select: - - AliasedExpr: - - Expr.char_length: - - Field: title - - "titleLength" - - title - - Where: - - Expr.greater_than: - - Field: titleLength - - Constant: 20 - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - titleLength: 29 - title: One Hundred Years of Solitude - - titleLength: 36 - title: The Hitchhiker's Guide to the Galaxy - - titleLength: 21 - title: The Lord of the Rings - - titleLength: 21 - title: To Kill a Mockingbird - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - titleLength: - functionValue: - args: - - fieldReferenceValue: title - name: char_length - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: titleLength - - integerValue: '20' - name: greater_than - name: where - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testStringFunctions - CharLength - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Douglas Adams" - - Select: - - AliasedExpr: - - Expr.char_length: - - Field: title - - "title_length" - assert_results: - - title_length: 36 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - title_length: - functionValue: - args: - - fieldReferenceValue: title - name: char_length - name: select - - description: testStringFunctions - ByteLength - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: Douglas Adams - - Select: - - AliasedExpr: - - Expr.byte_length: - - Expr.string_concat: - - Field: title - - Constant: _银河系漫游指南 - - "title_byte_length" - assert_results: - - title_byte_length: 58 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - title_byte_length: - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" - name: string_concat - name: byte_length - name: select - - description: testLike - pipeline: - - Collection: books - - Where: - - Expr.like: - - Field: title - - Constant: "%Guide%" - - Select: - - title - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - - description: testRegexContains 
- # Find titles that contain either "the" or "of" (case-insensitive) - pipeline: - - Collection: books - - Where: - - Expr.regex_contains: - - Field: title - - Constant: "(?i)(the|of)" - assert_count: 5 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "(?i)(the|of)" - name: regex_contains - name: where - - description: testRegexMatches - # Find titles that contain either "the" or "of" (case-insensitive) - pipeline: - - Collection: books - - Where: - - Expr.regex_match: - - Field: title - - Constant: ".*(?i)(the|of).*" - assert_count: 5 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: ".*(?i)(the|of).*" - name: regex_match - name: where - - description: testArithmeticOperations - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: To Kill a Mockingbird - - Select: - - AliasedExpr: - - Expr.add: - - Field: rating - - Constant: 1 - - "ratingPlusOne" - - AliasedExpr: - - Expr.subtract: - - Field: published - - Constant: 1900 - - "yearsSince1900" - - AliasedExpr: - - Expr.multiply: - - Field: rating - - Constant: 10 - - "ratingTimesTen" - - AliasedExpr: - - Expr.divide: - - Field: rating - - Constant: 2 - - "ratingDividedByTwo" - - AliasedExpr: - - Expr.multiply: - - Field: rating - - Constant: 20 - - "ratingTimes20" - - AliasedExpr: - - Expr.add: - - Field: rating - - Constant: 3 - - "ratingPlus3" - - AliasedExpr: - - Expr.mod: - - Field: rating - - Constant: 2 - - "ratingMod2" - assert_results: - - ratingPlusOne: 5.2 - yearsSince1900: 60 - ratingTimesTen: 42.0 - ratingDividedByTwo: 2.1 - ratingTimes20: 84 - ratingPlus3: 7.2 - ratingMod2: 0.20000000000000018 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: To Kill a Mockingbird - name: equal - name: where - - args: - - mapValue: - fields: - ratingDividedByTwo: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: divide - ratingPlusOne: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '1' - name: add - ratingTimesTen: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '10' - name: multiply - yearsSince1900: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: subtract - ratingTimes20: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '20' - name: multiply - ratingPlus3: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '3' - name: add - ratingMod2: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: mod - name: select - - description: testComparisonOperators - pipeline: - - Collection: books - - Where: - - And: - - Expr.greater_than: - - Field: rating - - Constant: 4.2 - - Expr.less_than_or_equal: - - Field: rating - - Constant: 4.5 - - Expr.not_equal: - - Field: genre - - Constant: Science Fiction - - Select: - - rating - - title - - Sort: - - Ordering: - - title - - ASCENDING - assert_results: - - rating: 4.3 - title: Crime and Punishment - - rating: 4.3 - title: One Hundred Years of Solitude - - rating: 4.5 - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: 
collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.2 - name: greater_than - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: less_than_or_equal - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: not_equal - name: and - name: where - - args: - - mapValue: - fields: - rating: - fieldReferenceValue: rating - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testLogicalOperators - pipeline: - - Collection: books - - Where: - - Or: - - And: - - Expr.greater_than: - - Field: rating - - Constant: 4.5 - - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Expr.less_than: - - Field: published - - Constant: 1900 - - Select: - - title - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: Crime and Punishment - - title: Dune - - title: Pride and Prejudice - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: greater_than - - functionValue: - args: - - fieldReferenceValue: genre - - stringValue: Science Fiction - name: equal - name: and - - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: less_than - name: or - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testChecks - pipeline: - - Collection: books - - Where: - - Not: - - Expr.is_nan: - - Field: rating - - Select: - - AliasedExpr: - - Not: - - Expr.is_nan: - - Field: rating - - "ratingIsNotNaN" - - Limit: 1 - assert_results: - - ratingIsNotNaN: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: where - - args: - - mapValue: - fields: - ratingIsNotNaN: - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: select - - args: - - integerValue: '1' - name: limit - - description: testIsNotNull - pipeline: - - Collection: books - - Where: - - Expr.is_not_null: - - Field: rating - assert_count: 10 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_not_null - name: where - - description: testIsNotNaN - pipeline: - - Collection: books - - Where: - - Expr.is_not_nan: - - Field: rating - assert_count: 10 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_not_nan - name: where - - description: testIsAbsent - pipeline: - - Collection: books - - Where: - - Expr.is_absent: - - Field: awards.pulitzer - assert_count: 9 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: awards.pulitzer - name: is_absent - name: where - - 
description: testIfAbsent - pipeline: - - Collection: books - - Select: - - AliasedExpr: - - Expr.if_absent: - - Field: awards.pulitzer - - Constant: false - - "pulitzer_award" - - title - - Where: - - Expr.equal: - - Field: pulitzer_award - - Constant: true - assert_results: - - pulitzer_award: true - title: To Kill a Mockingbird - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - pulitzer_award: - functionValue: - name: if_absent - args: - - fieldReferenceValue: awards.pulitzer - - booleanValue: false - title: - fieldReferenceValue: title - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: pulitzer_award - - booleanValue: true - name: equal - name: where - - description: testIsError - pipeline: - - Collection: books - - Select: - - AliasedExpr: - - Expr.is_error: - - Expr.divide: - - Field: rating - - Constant: "string" - - "is_error_result" - - Limit: 1 - assert_results: - - is_error_result: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - is_error_result: - functionValue: - name: is_error - args: - - functionValue: - name: divide - args: - - fieldReferenceValue: rating - - stringValue: "string" - name: select - - args: - - integerValue: '1' - name: limit - - description: testIfError - pipeline: - - Collection: books - - Select: - - AliasedExpr: - - Expr.if_error: - - Expr.divide: - - Field: rating - - Field: genre - - Constant: "An error occurred" - - "if_error_result" - - Limit: 1 - assert_results: - - if_error_result: "An error occurred" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - if_error_result: - functionValue: - name: if_error - args: - - functionValue: - name: divide - args: - - fieldReferenceValue: rating - - fieldReferenceValue: genre - - stringValue: "An error occurred" - name: select - - args: - - integerValue: '1' - name: limit - - description: testLogicalMinMax - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: Douglas Adams - - Select: - - AliasedExpr: - - Expr.logical_maximum: - - Field: rating - - Constant: 4.5 - - "max_rating" - - AliasedExpr: - - Expr.logical_minimum: - - Field: published - - Constant: 1900 - - "min_published" - assert_results: - - max_rating: 4.5 - min_published: 1900 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - min_published: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: minimum - max_rating: - functionValue: - args: - - fieldReferenceValue: rating - - doubleValue: 4.5 - name: maximum - name: select - - description: testMapGet - pipeline: - - Collection: books - - Sort: - - Ordering: - - Field: published - - DESCENDING - - Select: - - AliasedExpr: - - Expr.map_get: - - Field: awards - - hugo - - "hugoAward" - - Field: title - - Where: - - Expr.equal: - - Field: hugoAward - - Constant: true - assert_results: - - hugoAward: true - title: The Hitchhiker's Guide to the Galaxy - - hugoAward: true - title: Dune - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - direction: - stringValue: descending - expression: 
- fieldReferenceValue: published - name: sort - - args: - - mapValue: - fields: - hugoAward: - functionValue: - args: - - fieldReferenceValue: awards - - stringValue: hugo - name: map_get - title: - fieldReferenceValue: title - name: select - - args: - - functionValue: - args: - - fieldReferenceValue: hugoAward - - booleanValue: true - name: equal - name: where - - description: testMapGetWithField - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "Dune" - - AddFields: - - AliasedExpr: - - Constant: "hugo" - - "award_name" - - Select: - - AliasedExpr: - - Expr.map_get: - - Field: awards - - Field: award_name - - "hugoAward" - - Field: title - assert_results: - - hugoAward: true - title: Dune - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "Dune" - name: equal - name: where - - args: - - mapValue: - fields: - award_name: - stringValue: "hugo" - name: add_fields - - args: - - mapValue: - fields: - hugoAward: - functionValue: - name: map_get - args: - - fieldReferenceValue: awards - - fieldReferenceValue: award_name - title: - fieldReferenceValue: title - name: select - - description: testMapRemove - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "Dune" - - Select: - - AliasedExpr: - - Expr.map_remove: - - Field: awards - - "nebula" - - "awards_removed" - assert_results: - - awards_removed: - hugo: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "Dune" - name: equal - name: where - - args: - - mapValue: - fields: - awards_removed: - functionValue: - name: map_remove - args: - - fieldReferenceValue: awards - - stringValue: "nebula" - name: select - - description: testMapMerge - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "Dune" - - Select: - - AliasedExpr: - - Expr.map_merge: - - Field: awards - - Map: - elements: {"new_award": true, "hugo": false} - - Map: - elements: {"another_award": "yes"} - - "awards_merged" - assert_results: - - awards_merged: - hugo: false - nebula: true - new_award: true - another_award: "yes" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "Dune" - name: equal - name: where - - args: - - mapValue: - fields: - awards_merged: - functionValue: - name: map_merge - args: - - fieldReferenceValue: awards - - functionValue: - name: map - args: - - stringValue: "new_award" - - booleanValue: true - - stringValue: "hugo" - - booleanValue: false - - functionValue: - name: map - args: - - stringValue: "another_award" - - stringValue: "yes" - name: select - - description: testNestedFields - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: awards.hugo - - Constant: true - - Sort: - - Ordering: - - Field: title - - DESCENDING - - Select: - - title - - Field: awards.hugo - assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: awards.hugo - - booleanValue: true - name: equal - name: where - - args: - - mapValue: - fields: - 
direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort - - args: - - mapValue: - fields: - awards.hugo: - fieldReferenceValue: awards.hugo - title: - fieldReferenceValue: title - name: select - - description: testSampleLimit - pipeline: - - Collection: books - - Sample: 3 - assert_count: 3 # Results will vary due to randomness - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - integerValue: '3' - - stringValue: documents - name: sample - - description: testSamplePercentage - pipeline: - - Collection: books - - Sample: - - SampleOptions: - - 0.6 - - percent - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - doubleValue: 0.6 - - stringValue: percent - name: sample - - description: testUnion - pipeline: - - Collection: books - - Union: - - Pipeline: - - Collection: books - assert_count: 20 # Results will be duplicated - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - pipelineValue: - stages: - - args: - - referenceValue: /books - name: collection - name: union - - description: testUnnest - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: The Hitchhiker's Guide to the Galaxy - - Unnest: - - tags - - tags_alias - - Select: tags_alias - assert_results: - - tags_alias: comedy - - tags_alias: space - - tags_alias: adventure - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: The Hitchhiker's Guide to the Galaxy - name: equal - name: where - - args: - - fieldReferenceValue: tags - - fieldReferenceValue: tags_alias - name: unnest - - args: - - mapValue: - fields: - tags_alias: - fieldReferenceValue: tags_alias - name: select - - description: testGreaterThanOrEqual - pipeline: - - Collection: books - - Where: - - Expr.greater_than_or_equal: - - Field: rating - - Constant: 4.6 - - Select: - - title - - rating - - Sort: - - Ordering: - - Field: rating - - ASCENDING - assert_results: - - title: Dune - rating: 4.6 - - title: The Lord of the Rings - rating: 4.7 - - description: testInAndNotIn - pipeline: - - Collection: books - - Where: - - And: - - Expr.equal_any: - - Field: genre - - - Constant: Romance - - Constant: Dystopian - - Expr.not_equal_any: - - Field: author - - - Constant: "George Orwell" - assert_results: - - title: "Pride and Prejudice" - author: "Jane Austen" - genre: "Romance" - published: 1813 - rating: 4.5 - tags: - - classic - - social commentary - - love - awards: - none: true - - title: "The Handmaid's Tale" - author: "Margaret Atwood" - genre: "Dystopian" - published: 1985 - rating: 4.1 - tags: - - feminism - - totalitarianism - - resistance - awards: - "arthur c. 
clarke": true - "booker prize": false - - description: testArrayReverse - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Select: - - AliasedExpr: - - Expr.array_reverse: - - Field: tags - - "reversedTags" - assert_results: - - reversedTags: - - adventure - - space - - comedy - - description: testDocumentId - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Select: - - AliasedExpr: - - Expr.document_id: - - Field: __name__ - - "doc_id" - assert_results: - - doc_id: "book1" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "The Hitchhiker's Guide to the Galaxy" - name: equal - name: where - - args: - - mapValue: - fields: - doc_id: - functionValue: - name: document_id - args: - - fieldReferenceValue: __name__ - name: select - - description: testCurrentTimestamp - pipeline: - - Collection: books - - Limit: 1 - - Select: - - AliasedExpr: - - And: - - Expr.greater_than_or_equal: - - CurrentTimestamp: [] - - Expr.unix_seconds_to_timestamp: - - Constant: 1735689600 # 2025-01-01 - - Expr.less_than: - - CurrentTimestamp: [] - - Expr.unix_seconds_to_timestamp: - - Constant: 4892438400 # 2125-01-01 - - "is_between_2025_and_2125" - assert_results: - - is_between_2025_and_2125: true - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - integerValue: '1' - name: limit - - args: - - mapValue: - fields: - is_between_2025_and_2125: - functionValue: - name: and - args: - - functionValue: - name: greater_than_or_equal - args: - - functionValue: - name: current_timestamp - - functionValue: - name: unix_seconds_to_timestamp - args: - - integerValue: '1735689600' - - functionValue: - name: less_than - args: - - functionValue: - name: current_timestamp - - functionValue: - name: unix_seconds_to_timestamp - args: - - integerValue: '4892438400' - name: select - - description: testArrayConcat - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "The Hitchhiker's Guide to the Galaxy" - - Select: - - AliasedExpr: - - Expr.array_concat: - - Field: tags - - Constant: ["new_tag", "another_tag"] - - "concatenatedTags" - assert_results: - - concatenatedTags: - - comedy - - space - - adventure - - new_tag - - another_tag - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: "The Hitchhiker's Guide to the Galaxy" - name: equal - name: where - - args: - - mapValue: - fields: - concatenatedTags: - functionValue: - args: - - fieldReferenceValue: tags - - functionValue: - args: - - stringValue: "new_tag" - - stringValue: "another_tag" - name: array - name: array_concat - name: select - - description: testArrayConcatMultiple - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "Dune" - - Select: - - AliasedExpr: - - Expr.array_concat: - - Field: tags - - Constant: ["sci-fi"] - - Constant: ["classic", "epic"] - - "concatenatedTags" - assert_results: - - concatenatedTags: - - politics - - desert - - ecology - - sci-fi - - classic - - epic - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - 
fieldReferenceValue: title - - stringValue: "Dune" - name: equal - name: where - - args: - - mapValue: - fields: - concatenatedTags: - functionValue: - args: - - fieldReferenceValue: tags - - functionValue: - args: - - stringValue: "sci-fi" - name: array - - functionValue: - args: - - stringValue: "classic" - - stringValue: "epic" - name: array - name: array_concat - name: select - - description: testMapMergeLiterals - pipeline: - - Collection: books - - Limit: 1 - - Select: - - AliasedExpr: - - Expr.map_merge: - - Map: - elements: {"a": "orig", "b": "orig"} - - Map: - elements: {"b": "new", "c": "new"} - - "merged" - assert_results: - - merged: - a: "orig" - b: "new" - c: "new" - - description: testArrayContainsAnyWithField - pipeline: - - Collection: books - - AddFields: - - AliasedExpr: - - Expr.array_concat: - - Field: tags - - Array: ["Dystopian"] - - "new_tags" - - Where: - - Expr.array_contains_any: - - Field: new_tags - - - Constant: non_existent_tag - - Field: genre - - Select: - - title - - genre - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "1984" - genre: "Dystopian" - - title: "The Handmaid's Tale" - genre: "Dystopian" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - mapValue: - fields: - new_tags: - functionValue: - args: - - fieldReferenceValue: tags - - functionValue: - args: - - stringValue: "Dystopian" - name: array - name: array_concat - name: add_fields - - args: - - functionValue: - args: - - fieldReferenceValue: new_tags - - functionValue: - args: - - stringValue: "non_existent_tag" - - fieldReferenceValue: genre - name: array - name: array_contains_any - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - genre: - fieldReferenceValue: genre - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testArrayConcatLiterals - pipeline: - - Collection: books - - Limit: 1 - - Select: - - AliasedExpr: - - Expr.array_concat: - - Array: [1, 2, 3] - - Array: [4, 5] - - "concatenated" - assert_results: - - concatenated: [1, 2, 3, 4, 5] - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - integerValue: '1' - name: limit - - args: - - mapValue: - fields: - concatenated: - functionValue: - args: - - functionValue: - args: - - integerValue: '1' - - integerValue: '2' - - integerValue: '3' - name: array - - functionValue: - args: - - integerValue: '4' - - integerValue: '5' - name: array - name: array_concat - name: select - - description: testExists - pipeline: - - Collection: books - - Where: - - And: - - Expr.exists: - - Field: awards.pulitzer - - Expr.equal: - - Field: awards.pulitzer - - Constant: true - - Select: - - title - assert_results: - - title: To Kill a Mockingbird - - description: testSum - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Aggregate: - - AliasedExpr: - - Expr.sum: - - Field: rating - - "total_rating" - assert_results: - - total_rating: 8.8 - - description: testStringContains - pipeline: - - Collection: books - - Where: - - Expr.string_contains: - - Field: title - - Constant: "Hitchhiker's" - - Select: - - title - assert_results: - - title: "The Hitchhiker's Guide to the Galaxy" - - description: testVectorLength - pipeline: - - Collection: vectors - - Select: - - AliasedExpr: - - Expr.vector_length: - - Field: 
embedding - - "embedding_length" - - Sort: - - Ordering: - - Field: embedding_length - - ASCENDING - assert_results: - - embedding_length: 3 - - embedding_length: 3 - - embedding_length: 3 - - embedding_length: 4 - - description: testTimestampFunctions - pipeline: - - Collection: timestamps - - Select: - - AliasedExpr: - - Expr.timestamp_to_unix_micros: - - Field: time - - "micros" - - AliasedExpr: - - Expr.timestamp_to_unix_millis: - - Field: time - - "millis" - - AliasedExpr: - - Expr.timestamp_to_unix_seconds: - - Field: time - - "seconds" - - AliasedExpr: - - Expr.unix_micros_to_timestamp: - - Field: micros - - "from_micros" - - AliasedExpr: - - Expr.unix_millis_to_timestamp: - - Field: millis - - "from_millis" - - AliasedExpr: - - Expr.unix_seconds_to_timestamp: - - Field: seconds - - "from_seconds" - - AliasedExpr: - - Expr.timestamp_add: - - Field: time - - Constant: "day" - - Constant: 1 - - "plus_day" - - AliasedExpr: - - Expr.timestamp_subtract: - - Field: time - - Constant: "hour" - - Constant: 1 - - "minus_hour" - assert_results: - - micros: 735998460654321 - millis: 735998460654 - seconds: 735998460 - from_micros: "1993-04-28T12:01:00.654321+00:00" - from_millis: "1993-04-28T12:01:00.654000+00:00" - from_seconds: "1993-04-28T12:01:00.000000+00:00" - plus_day: "1993-04-29T12:01:00.654321+00:00" - minus_hour: "1993-04-28T11:01:00.654321+00:00" - - description: testCollectionId - pipeline: - - Collection: books - - Limit: 1 - - Select: - - AliasedExpr: - - Expr.collection_id: - - Field: __name__ - - "collectionName" - assert_results: - - collectionName: "books" - - description: testXor - pipeline: - - Collection: books - - Where: - - Xor: - - - Expr.equal: - - Field: genre - - Constant: Romance - - Expr.greater_than: - - Field: published - - Constant: 1980 - - Select: - - title - - genre - - published - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "Pride and Prejudice" - genre: "Romance" - published: 1813 - - title: "The Handmaid's Tale" - genre: "Dystopian" - published: 1985 - - description: testConditional - pipeline: - - Collection: books - - Select: - - title - - AliasedExpr: - - Conditional: - - Expr.greater_than: - - Field: published - - Constant: 1950 - - Constant: "Modern" - - Constant: "Classic" - - "era" - - Sort: - - Ordering: - - Field: title - - ASCENDING - - Limit: 4 - assert_results: - - title: "1984" - era: "Classic" - - title: "Crime and Punishment" - era: "Classic" - - title: "Dune" - era: "Modern" - - title: "One Hundred Years of Solitude" - era: "Modern" - - description: testFieldToFieldArithmetic - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: "Dune" - - Select: - - AliasedExpr: - - Expr.add: - - Field: published - - Field: rating - - "pub_plus_rating" - assert_results: - - pub_plus_rating: 1969.6 - - description: testFieldToFieldComparison - pipeline: - - Collection: books - - Where: - - Expr.greater_than: - - Field: published - - Field: rating - - Select: - - title - assert_count: 10 # All books were published after year 4.7 - - description: testExistsNegative - pipeline: - - Collection: books - - Where: - - Expr.exists: - - Field: non_existent_field - assert_count: 0 - - description: testConditionalWithFields - pipeline: - - Collection: books - - Where: - - Expr.equal_any: - - Field: title - - - Constant: "Dune" - - Constant: "1984" - - Select: - - title - - AliasedExpr: - - Conditional: - - Expr.greater_than: - - Field: published - - Constant: 1950 - - Field: author - - 
Field: genre - - "conditional_field" - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "1984" - conditional_field: "Dystopian" - - title: "Dune" - conditional_field: "Frank Herbert" - - description: testFindNearestEuclidean - pipeline: - - Collection: vectors - - FindNearest: - field: embedding - vector: [1.0, 2.0, 3.0] - distance_measure: EUCLIDEAN - options: - FindNearestOptions: - limit: 2 - distance_field: - Field: distance - - Select: - - distance - assert_results: - - distance: 0.0 - - distance: 1.0 - assert_proto: - pipeline: - stages: - - name: collection - args: - - referenceValue: /vectors - - name: find_nearest - args: - - fieldReferenceValue: embedding - - mapValue: - fields: - __type__: - stringValue: __vector__ - value: - arrayValue: - values: - - doubleValue: 1.0 - - doubleValue: 2.0 - - doubleValue: 3.0 - - stringValue: euclidean - options: - limit: - integerValue: '2' - distance_field: - fieldReferenceValue: distance - - name: select - args: - - mapValue: - fields: - distance: - fieldReferenceValue: distance - - description: testMathExpressions - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: To Kill a Mockingbird - - Select: - - AliasedExpr: - - Expr.abs: - - Field: rating - - "abs_rating" - - AliasedExpr: - - Expr.ceil: - - Field: rating - - "ceil_rating" - - AliasedExpr: - - Expr.exp: - - Field: rating - - "exp_rating" - - AliasedExpr: - - Expr.floor: - - Field: rating - - "floor_rating" - - AliasedExpr: - - Expr.ln: - - Field: rating - - "ln_rating" - - AliasedExpr: - - Expr.log10: - - Field: rating - - "log_rating_base10" - - AliasedExpr: - - Expr.log: - - Field: rating - - Constant: 2 - - "log_rating_base2" - - AliasedExpr: - - Expr.pow: - - Field: rating - - Constant: 2 - - "pow_rating" - - AliasedExpr: - - Expr.sqrt: - - Field: rating - - "sqrt_rating" - assert_results_approximate: - - abs_rating: 4.2 - ceil_rating: 5.0 - exp_rating: 66.686331 - floor_rating: 4.0 - ln_rating: 1.4350845 - log_rating_base10: 0.623249 - log_rating_base2: 2.0704 - pow_rating: 17.64 - sqrt_rating: 2.049390 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: To Kill a Mockingbird - name: equal - name: where - - args: - - mapValue: - fields: - abs_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: abs - ceil_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: ceil - exp_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: exp - floor_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: floor - ln_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: ln - log_rating_base10: - functionValue: - args: - - fieldReferenceValue: rating - name: log10 - log_rating_base2: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: log - pow_rating: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: pow - sqrt_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: sqrt - name: select - - description: testRoundExpressions - pipeline: - - Collection: books - - Where: - - Expr.equal_any: - - Field: title - - - Constant: "To Kill a Mockingbird" # rating 4.2 - - Constant: "Pride and Prejudice" # rating 4.5 - - Constant: "The Lord of the Rings" # rating 4.7 - - Select: - - title - - AliasedExpr: - - Expr.round: - - Field: 
rating - - "round_rating" - - Sort: - - Ordering: - - Field: title - - ASCENDING - assert_results: - - title: "Pride and Prejudice" - round_rating: 5.0 - - title: "The Lord of the Rings" - round_rating: 5.0 - - title: "To Kill a Mockingbird" - round_rating: 4.0 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - functionValue: - args: - - stringValue: "To Kill a Mockingbird" - - stringValue: "Pride and Prejudice" - - stringValue: "The Lord of the Rings" - name: array - name: equal_any - name: where - - args: - - mapValue: - fields: - title: - fieldReferenceValue: title - round_rating: - functionValue: - args: - - fieldReferenceValue: rating - name: round - name: select - - args: - - mapValue: - fields: - direction: - stringValue: ascending - expression: - fieldReferenceValue: title - name: sort - - description: testFindNearestDotProduct - pipeline: - - Collection: vectors - - FindNearest: - field: embedding - vector: [1.0, 2.0, 3.0] - distance_measure: DOT_PRODUCT - options: - FindNearestOptions: - limit: 3 - distance_field: - Field: distance - - Select: - - distance - assert_results: - - distance: 38.0 - - distance: 17.0 - - distance: 14.0 - assert_proto: - pipeline: - stages: - - name: collection - args: - - referenceValue: /vectors - - name: find_nearest - args: - - fieldReferenceValue: embedding - - mapValue: - fields: - __type__: - stringValue: __vector__ - value: - arrayValue: - values: - - doubleValue: 1.0 - - doubleValue: 2.0 - - doubleValue: 3.0 - - stringValue: dot_product - options: - limit: - integerValue: '3' - distance_field: - fieldReferenceValue: distance - - name: select - args: - - mapValue: - fields: - distance: - fieldReferenceValue: distance - - description: testDotProductWithConstant - pipeline: - - Collection: vectors - - Where: - - Expr.equal: - - Field: embedding - - Vector: [1.0, 2.0, 3.0] - - Select: - - AliasedExpr: - - Expr.dot_product: - - Field: embedding - - Vector: [1.0, 1.0, 1.0] - - "dot_product_result" - assert_results: - - dot_product_result: 6.0 - - description: testEuclideanDistanceWithConstant - pipeline: - - Collection: vectors - - Where: - - Expr.equal: - - Field: embedding - - Vector: [1.0, 2.0, 3.0] - - Select: - - AliasedExpr: - - Expr.euclidean_distance: - - Field: embedding - - Vector: [1.0, 2.0, 3.0] - - "euclidean_distance_result" - assert_results: - - euclidean_distance_result: 0.0 - - description: testCosineDistanceWithConstant - pipeline: - - Collection: vectors - - Where: - - Expr.equal: - - Field: embedding - - Vector: [1.0, 2.0, 3.0] - - Select: - - AliasedExpr: - - Expr.cosine_distance: - - Field: embedding - - Vector: [1.0, 2.0, 3.0] - - "cosine_distance_result" - assert_results: - - cosine_distance_result: 0.0 - - description: testStringFunctions - ToLower - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Douglas Adams" - - Select: - - AliasedExpr: - - Expr.to_lower: - - Field: title - - "lower_title" - assert_results: - - lower_title: "the hitchhiker's guide to the galaxy" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - lower_title: - functionValue: - args: - - fieldReferenceValue: title - name: to_lower - name: select - - description: testStringFunctions - ToUpper - 
pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Douglas Adams" - - Select: - - AliasedExpr: - - Expr.to_upper: - - Field: title - - "upper_title" - assert_results: - - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - upper_title: - functionValue: - args: - - fieldReferenceValue: title - name: to_upper - name: select - - description: testStringFunctions - Trim - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Douglas Adams" - - Select: - - AliasedExpr: - - Expr.trim: - - Expr.string_concat: - - Constant: " " - - Field: title - - Constant: " " - - "trimmed_title" - assert_results: - - trimmed_title: "The Hitchhiker's Guide to the Galaxy" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: Douglas Adams - name: equal - name: where - - args: - - mapValue: - fields: - trimmed_title: - functionValue: - args: - - functionValue: - args: - - stringValue: " " - - fieldReferenceValue: title - - stringValue: " " - name: string_concat - name: trim - name: select - - description: testStringFunctions - StringReverse - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Jane Austen" - - Select: - - AliasedExpr: - - Expr.string_reverse: - - Field: title - - "reversed_title" - assert_results: - - reversed_title: "ecidujerP dna edirP" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: "Jane Austen" - name: equal - name: where - - args: - - mapValue: - fields: - reversed_title: - functionValue: - args: - - fieldReferenceValue: title - name: string_reverse - name: select - - description: testStringFunctions - Substring - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Douglas Adams" - - Select: - - AliasedExpr: - - Expr.substring: - - Field: title - - Constant: 4 - - Constant: 11 - - "substring_title" - assert_results: - - substring_title: "Hitchhiker'" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: "Douglas Adams" - name: equal - name: where - - args: - - mapValue: - fields: - substring_title: - functionValue: - args: - - fieldReferenceValue: title - - integerValue: '4' - - integerValue: '11' - name: substring - name: select - - description: testStringFunctions - Substring without length - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: author - - Constant: "Fyodor Dostoevsky" - - Select: - - AliasedExpr: - - Expr.substring: - - Field: title - - Constant: 10 - - "substring_title" - assert_results: - - substring_title: "Punishment" - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: author - - stringValue: "Fyodor Dostoevsky" - name: equal - name: where - - args: - - mapValue: - fields: - substring_title: - functionValue: - args: - - fieldReferenceValue: title - - integerValue: 
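The two Substring expectations above are consistent with 0-based positions and an optional length argument, which plain Python slicing reproduces:

title = "The Hitchhiker's Guide to the Galaxy"
assert title[4:4 + 11] == "Hitchhiker'"   # substring(title, 4, 11)

title2 = "Crime and Punishment"
assert title2[10:] == "Punishment"        # substring(title, 10): rest of string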
'10'
-                        name: substring
-            name: select
-  - description: testStringFunctions - Join
-    pipeline:
-      - Collection: books
-      - Where:
-          - Expr.equal:
-              - Field: author
-              - Constant: "Douglas Adams"
-      - Select:
-          - AliasedExpr:
-              - Expr.join:
-                  - Field: tags
-                  - Constant: ", "
-              - "joined_tags"
-    assert_results:
-      - joined_tags: "comedy, space, adventure"
-    assert_proto:
-      pipeline:
-        stages:
-          - args:
-              - referenceValue: /books
-            name: collection
-          - args:
-              - functionValue:
-                  args:
-                    - fieldReferenceValue: author
-                    - stringValue: "Douglas Adams"
-                  name: equal
-            name: where
-          - args:
-              - mapValue:
-                  fields:
-                    joined_tags:
-                      functionValue:
-                        args:
-                          - fieldReferenceValue: tags
-                          - stringValue: ", "
-                        name: join
-            name: select
diff --git a/tests/system/pipeline_e2e/aggregates.yaml b/tests/system/pipeline_e2e/aggregates.yaml
new file mode 100644
index 000000000..18902aff4
--- /dev/null
+++ b/tests/system/pipeline_e2e/aggregates.yaml
@@ -0,0 +1,269 @@
+tests:
+  - description: "testAggregates - count"
+    pipeline:
+      - Collection: books
+      - Aggregate:
+          - AliasedExpr:
+              - Expr.count:
+                  - Field: rating
+              - "count"
+    assert_results:
+      - count: 10
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - mapValue:
+                  fields:
+                    count:
+                      functionValue:
+                        name: count
+                        args:
+                          - fieldReferenceValue: rating
+              - mapValue: {}
+            name: aggregate
+  - description: "testAggregates - count_if"
+    pipeline:
+      - Collection: books
+      - Aggregate:
+          - AliasedExpr:
+              - Expr.count_if:
+                  - Expr.greater_than:
+                      - Field: rating
+                      - Constant: 4.2
+              - "count_if_rating_gt_4_2"
+    assert_results:
+      - count_if_rating_gt_4_2: 5
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - mapValue:
+                  fields:
+                    count_if_rating_gt_4_2:
+                      functionValue:
+                        name: count_if
+                        args:
+                          - functionValue:
+                              name: greater_than
+                              args:
+                                - fieldReferenceValue: rating
+                                - doubleValue: 4.2
+              - mapValue: {}
+            name: aggregate
+  - description: "testAggregates - count_distinct"
+    pipeline:
+      - Collection: books
+      - Aggregate:
+          - AliasedExpr:
+              - Expr.count_distinct:
+                  - Field: genre
+              - "distinct_genres"
+    assert_results:
+      - distinct_genres: 8
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - mapValue:
+                  fields:
+                    distinct_genres:
+                      functionValue:
+                        name: count_distinct
+                        args:
+                          - fieldReferenceValue: genre
+              - mapValue: {}
+            name: aggregate
+  - description: "testAggregates - avg, count, max"
+    pipeline:
+      - Collection: books
+      - Where:
+          - Expr.equal:
+              - Field: genre
+              - Constant: Science Fiction
+      - Aggregate:
+          - AliasedExpr:
+              - Expr.count:
+                  - Field: rating
+              - "count"
+          - AliasedExpr:
+              - Expr.average:
+                  - Field: rating
+              - "avg_rating"
+          - AliasedExpr:
+              - Expr.maximum:
+                  - Field: rating
+              - "max_rating"
+    assert_results:
+      - count: 2
+        avg_rating: 4.4
+        max_rating: 4.6
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - functionValue:
+                  args:
+                    - fieldReferenceValue: genre
+                    - stringValue: Science Fiction
+                  name: equal
+            name: where
+          - args:
+              - mapValue:
+                  fields:
+                    avg_rating:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: rating
+                        name: average
+                    count:
+                      functionValue:
+                        name: count
+                        args:
+                          - fieldReferenceValue: rating
+                    max_rating:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: rating
+                        name: maximum
+              - mapValue: {}
+            name: aggregate
+  - description: testGroupBysWithoutAccumulators
+    pipeline:
+      - Collection: books
+      - Where:
+          - Expr.less_than:
+              - Field: published
+              - Constant: 1900
+      - Aggregate:
+          accumulators: []
+          groups: [genre]
+    assert_error: ".* requires at least one accumulator"
+  - description: testGroupBysAndAggregate
+    pipeline:
+      - Collection: books
+      - Where:
+          - Expr.less_than:
+              - Field: published
+              - Constant: 1984
+      - Aggregate:
+          accumulators:
+            - AliasedExpr:
+                - Expr.average:
+                    - Field: rating
+                - "avg_rating"
+          groups: [genre]
+      - Where:
+          - Expr.greater_than:
+              - Field: avg_rating
+              - Constant: 4.3
+      - Sort:
+          - Ordering:
+              - Field: avg_rating
+              - ASCENDING
+    assert_results:
+      - avg_rating: 4.4
+        genre: Science Fiction
+      - avg_rating: 4.5
+        genre: Romance
+      - avg_rating: 4.7
+        genre: Fantasy
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - functionValue:
+                  args:
+                    - fieldReferenceValue: published
+                    - integerValue: '1984'
+                  name: less_than
+            name: where
+          - args:
+              - mapValue:
+                  fields:
+                    avg_rating:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: rating
+                        name: average
+              - mapValue:
+                  fields:
+                    genre:
+                      fieldReferenceValue: genre
+            name: aggregate
+          - args:
+              - functionValue:
+                  args:
+                    - fieldReferenceValue: avg_rating
+                    - doubleValue: 4.3
+                  name: greater_than
+            name: where
+          - args:
+              - mapValue:
+                  fields:
+                    direction:
+                      stringValue: ascending
+                    expression:
+                      fieldReferenceValue: avg_rating
+            name: sort
+  - description: testMinMax
+    pipeline:
+      - Collection: books
+      - Aggregate:
+          - AliasedExpr:
+              - Expr.count:
+                  - Field: rating
+              - "count"
+          - AliasedExpr:
+              - Expr.maximum:
+                  - Field: rating
+              - "max_rating"
+          - AliasedExpr:
+              - Expr.minimum:
+                  - Field: published
+              - "min_published"
+    assert_results:
+      - count: 10
+        max_rating: 4.7
+        min_published: 1813
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - mapValue:
+                  fields:
+                    count:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: rating
+                        name: count
+                    max_rating:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: rating
+                        name: maximum
+                    min_published:
+                      functionValue:
+                        args:
+                          - fieldReferenceValue: published
+                        name: minimum
+              - mapValue: {}
+            name: aggregate
\ No newline at end of file
diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml
new file mode 100644
index 000000000..d63a63402
--- /dev/null
+++ b/tests/system/pipeline_e2e/array.yaml
@@ -0,0 +1,388 @@
+tests:
+  - description: testArrayContains
+    pipeline:
+      - Collection: books
+      - Where:
+          - Expr.array_contains:
+              - Field: tags
+              - Constant: comedy
+    assert_results:
+      - title: The Hitchhiker's Guide to the Galaxy
+        author: Douglas Adams
+        awards:
+          hugo: true
+          nebula: false
+        genre: Science Fiction
+        published: 1979
+        rating: 4.2
+        tags: ["comedy", "space", "adventure"]
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - functionValue:
+                  args:
+                    - fieldReferenceValue: tags
+                    - stringValue: comedy
+                  name: array_contains
+            name: where
+  - description: testArrayContainsAny
+    pipeline:
+      - Collection: books
+      - Where:
+          - Expr.array_contains_any:
+              - Field: tags
+              - - Constant: comedy
+                - Constant: classic
+      - Select:
+          - title
+      - Sort:
+          - Ordering:
+              - Field: title
+              - ASCENDING
+    assert_results:
+      - title: Pride and Prejudice
+      - title: The Hitchhiker's Guide to the Galaxy
+    assert_proto:
+      pipeline:
+        stages:
+          - args:
+              - referenceValue: /books
+            name: collection
+          - args:
+              - functionValue:
+                  args:
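Each fixture file in this directory follows the same schema: a top-level `tests` list whose entries carry a `description`, a `pipeline` (a list of stage mappings), and one or more assertions (`assert_results`, `assert_results_approximate`, `assert_count`, `assert_error`, `assert_proto`). A minimal sketch of a pytest harness that could consume this schema follows; `build_pipeline` and `run_pipeline` are hypothetical placeholders for whatever the real harness in this PR provides.

# Sketch only: loads the YAML cases and dispatches on the assertion kind.
import pathlib

import pytest
import yaml

FIXTURE_DIR = pathlib.Path(__file__).parent


def load_cases():
    for path in sorted(FIXTURE_DIR.glob("*.yaml")):
        doc = yaml.safe_load(path.read_text())
        for case in doc.get("tests", []):  # data.yaml has no "tests" key
            yield pytest.param(case, id=case["description"])


@pytest.mark.parametrize("case", load_cases())
def test_pipeline_case(case):
    pipeline = build_pipeline(case["pipeline"])  # hypothetical helper
    if "assert_error" in case:
        with pytest.raises(Exception, match=case["assert_error"]):
            run_pipeline(pipeline)               # hypothetical helper
        return
    results = run_pipeline(pipeline)
    if "assert_count" in case:
        assert len(results) == case["assert_count"]
    if "assert_results" in case:
        assert [r.data() for r in results] == case["assert_results"]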
- stringValue: comedy + - stringValue: classic + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayContainsAll + pipeline: + - Collection: books + - Where: + - Expr.array_contains_all: + - Field: tags + - - Constant: adventure + - Constant: magic + - Select: + - title + assert_results: + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: adventure + - stringValue: magic + name: array + name: array_contains_all + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - description: testArrayLength + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.array_length: + - Field: tags + - "tagsCount" + - Where: + - Expr.equal: + - Field: tagsCount + - Constant: 3 + assert_results: # All documents have 3 tags + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + tagsCount: + functionValue: + args: + - fieldReferenceValue: tags + name: array_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: tagsCount + - integerValue: '3' + name: equal + name: where + - description: testArrayReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.array_reverse: + - Field: tags + - "reversedTags" + assert_results: + - reversedTags: + - adventure + - space + - comedy + - description: testArrayConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Array: ["new_tag", "another_tag"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - comedy + - space + - adventure + - new_tag + - another_tag + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "new_tag" + - stringValue: "another_tag" + name: array + name: array_concat + name: select + - description: testArrayConcatMultiple + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - ["sci-fi"] + - ["classic", "epic"] + - "concatenatedTags" + assert_results: + - concatenatedTags: + - politics + - desert + - ecology + - sci-fi + - classic + - epic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: 
+ fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "sci-fi" + name: array + - functionValue: + args: + - stringValue: "classic" + - stringValue: "epic" + name: array + name: array_concat + name: select + - description: testArrayContainsAnyWithField + pipeline: + - Collection: books + - AddFields: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Array: ["Dystopian"] + - "new_tags" + - Where: + - Expr.array_contains_any: + - Field: new_tags + - - Constant: non_existent_tag + - Field: genre + - Select: + - title + - genre + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + genre: "Dystopian" + - title: "The Handmaid's Tale" + genre: "Dystopian" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + new_tags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "Dystopian" + name: array + name: array_concat + name: add_fields + - args: + - functionValue: + args: + - fieldReferenceValue: new_tags + - functionValue: + args: + - stringValue: "non_existent_tag" + - fieldReferenceValue: genre + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + genre: + fieldReferenceValue: genre + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayConcatLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.array_concat: + - Array: [1, 2, 3] + - Array: [4, 5] + - "concatenated" + assert_results: + - concatenated: [1, 2, 3, 4, 5] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + concatenated: + functionValue: + args: + - functionValue: + args: + - integerValue: '1' + - integerValue: '2' + - integerValue: '3' + name: array + - functionValue: + args: + - integerValue: '4' + - integerValue: '5' + name: array + name: array_concat + name: select diff --git a/tests/system/pipeline_e2e/data.yaml b/tests/system/pipeline_e2e/data.yaml new file mode 100644 index 000000000..902f7782d --- /dev/null +++ b/tests/system/pipeline_e2e/data.yaml @@ -0,0 +1,142 @@ +data: + books: + book1: + title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + genre: "Science Fiction" + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + awards: + hugo: true + nebula: false + book2: + title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + book3: + title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + genre: "Magical Realism" + published: 1967 + rating: 4.3 + tags: + - family + - history + - fantasy + awards: + nobel: true + nebula: false + book4: + title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + genre: "Fantasy" + published: 1954 + rating: 4.7 + tags: + - adventure + - magic + - epic + awards: + hugo: false + nebula: false + book5: + title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + arthur c. 
clarke: true + booker prize: false + book6: + title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + genre: "Psychological Thriller" + published: 1866 + rating: 4.3 + tags: + - philosophy + - crime + - redemption + awards: + none: true + book7: + title: "To Kill a Mockingbird" + author: "Harper Lee" + genre: "Southern Gothic" + published: 1960 + rating: 4.2 + tags: + - racism + - injustice + - coming-of-age + awards: + pulitzer: true + book8: + title: "1984" + author: "George Orwell" + genre: "Dystopian" + published: 1949 + rating: 4.2 + tags: + - surveillance + - totalitarianism + - propaganda + awards: + prometheus: true + book9: + title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + genre: "Modernist" + published: 1925 + rating: 4.0 + tags: + - wealth + - american dream + - love + awards: + none: true + book10: + title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + timestamps: + ts1: + time: "1993-04-28T12:01:00.654321+00:00" + micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + vectors: + vec1: + embedding: [1.0, 2.0, 3.0] + vec2: + embedding: [4.0, 5.0, 6.0, 7.0] + vec3: + embedding: [5.0, 6.0, 7.0] + vec4: + embedding: [1.0, 2.0, 4.0] \ No newline at end of file diff --git a/tests/system/pipeline_e2e/date_and_time.yaml b/tests/system/pipeline_e2e/date_and_time.yaml new file mode 100644 index 000000000..bbb5f34fe --- /dev/null +++ b/tests/system/pipeline_e2e/date_and_time.yaml @@ -0,0 +1,103 @@ +tests: + - description: testCurrentTimestamp + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - And: + - Expr.greater_than_or_equal: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 1735689600 # 2025-01-01 + - Expr.less_than: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 4892438400 # 2125-01-01 + - "is_between_2025_and_2125" + assert_results: + - is_between_2025_and_2125: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + is_between_2025_and_2125: + functionValue: + name: and + args: + - functionValue: + name: greater_than_or_equal + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '1735689600' + - functionValue: + name: less_than + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '4892438400' + name: select + - description: testTimestampFunctions + pipeline: + - Collection: timestamps + - Select: + - AliasedExpr: + - Expr.timestamp_to_unix_micros: + - Field: time + - "micros" + - AliasedExpr: + - Expr.timestamp_to_unix_millis: + - Field: time + - "millis" + - AliasedExpr: + - Expr.timestamp_to_unix_seconds: + - Field: time + - "seconds" + - AliasedExpr: + - Expr.unix_micros_to_timestamp: + - Field: micros + - "from_micros" + - AliasedExpr: + - Expr.unix_millis_to_timestamp: + - Field: millis + - "from_millis" + - AliasedExpr: + - Expr.unix_seconds_to_timestamp: + - Field: seconds + - "from_seconds" + - AliasedExpr: + - Expr.timestamp_add: + - Field: time + - Constant: "day" + - Constant: 1 + - "plus_day" + - AliasedExpr: + - Expr.timestamp_subtract: + - Field: time + - Constant: "hour" + - Constant: 1 + - "minus_hour" + assert_results: + - micros: 735998460654321 + 
millis: 735998460654 + seconds: 735998460 + from_micros: "1993-04-28T12:01:00.654321+00:00" + from_millis: "1993-04-28T12:01:00.654000+00:00" + from_seconds: "1993-04-28T12:01:00.000000+00:00" + plus_day: "1993-04-29T12:01:00.654321+00:00" + minus_hour: "1993-04-28T11:01:00.654321+00:00" diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml new file mode 100644 index 000000000..c135853d1 --- /dev/null +++ b/tests/system/pipeline_e2e/general.yaml @@ -0,0 +1,784 @@ +tests: + - description: selectSpecificFields + pipeline: + - Collection: books + - Select: + - title + - author + - Sort: + - Ordering: + - Field: author + - ASCENDING + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + - title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + - title: "Dune" + author: "Frank Herbert" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + - title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + - title: "1984" + author: "George Orwell" + - title: "To Kill a Mockingbird" + author: "Harper Lee" + - title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - description: addAndRemoveFields + pipeline: + - Collection: books + - AddFields: + - AliasedExpr: + - Expr.string_concat: + - Field: author + - Constant: _ + - Field: title + - "author_title" + - AliasedExpr: + - Expr.string_concat: + - Field: title + - Constant: _ + - Field: author + - "title_author" + - RemoveFields: + - title_author + - tags + - awards + - rating + - title + - Field: published + - Field: genre + - Field: nestedField # Field does not exist, should be ignored + - Sort: + - Ordering: + - Field: author_title + - ASCENDING + assert_results: + - author: Douglas Adams + author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy + - author: F. Scott Fitzgerald + author_title: F. Scott Fitzgerald_The Great Gatsby + - author: Frank Herbert + author_title: Frank Herbert_Dune + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment + - author: Gabriel García Márquez + author_title: Gabriel García Márquez_One Hundred Years of Solitude + - author: George Orwell + author_title: George Orwell_1984 + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird + - author: J.R.R. Tolkien + author_title: J.R.R. 
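The timestamps fixture above is internally consistent, as a standard-library check confirms; note that the `from_millis` expectation ends in `.654000` because converting through millis drops the sub-millisecond digits.

from datetime import datetime

ts = datetime.fromisoformat("1993-04-28T12:01:00.654321+00:00")
seconds = int(ts.timestamp())
micros = seconds * 1_000_000 + ts.microsecond
assert seconds == 735998460            # timestamp_to_unix_seconds
assert micros == 735998460654321       # timestamp_to_unix_micros
assert micros // 1000 == 735998460654  # timestamp_to_unix_millis (truncates)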
Tolkien_The Lord of the Rings + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author_title: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: _ + - fieldReferenceValue: title + name: string_concat + title_author: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: _ + - fieldReferenceValue: author + name: string_concat + name: add_fields + - args: + - fieldReferenceValue: title_author + - fieldReferenceValue: tags + - fieldReferenceValue: awards + - fieldReferenceValue: rating + - fieldReferenceValue: title + - fieldReferenceValue: published + - fieldReferenceValue: genre + - fieldReferenceValue: nestedField + name: remove_fields + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author_title + name: sort + + - description: testPipelineWithOffsetAndLimit + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Offset: 5 + - Limit: 3 + - Select: + - title + - author + assert_results: + - title: "1984" + author: George Orwell + - title: To Kill a Mockingbird + author: Harper Lee + - title: The Lord of the Rings + author: J.R.R. Tolkien + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - integerValue: '5' + name: offset + - args: + - integerValue: '3' + name: limit + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - description: testArithmeticOperations + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.add: + - Field: rating + - Constant: 1 + - "ratingPlusOne" + - AliasedExpr: + - Expr.subtract: + - Field: published + - Constant: 1900 + - "yearsSince1900" + - AliasedExpr: + - Expr.multiply: + - Field: rating + - Constant: 10 + - "ratingTimesTen" + - AliasedExpr: + - Expr.divide: + - Field: rating + - Constant: 2 + - "ratingDividedByTwo" + - AliasedExpr: + - Expr.multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - AliasedExpr: + - Expr.add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - AliasedExpr: + - Expr.mod: + - Field: rating + - Constant: 2 + - "ratingMod2" + assert_results: + - ratingPlusOne: 5.2 + yearsSince1900: 60 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: 
subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: mod + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: + - 0.6 + - percent + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - doubleValue: 0.6 + - stringValue: percent + name: sample + - description: testUnion + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Romance + - Union: + - Pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: "The Handmaid's Tale" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: equal + name: where + - args: + - pipelineValue: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: equal + name: where + name: union + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testUnionFullCollection + pipeline: + - Collection: books + - Union: + - Pipeline: + - Collection: books + assert_count: 20 # Results will be duplicated + - description: testDocumentId + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.document_id: + - Field: __name__ + - "doc_id" + assert_results: + - doc_id: "book1" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + doc_id: + functionValue: + name: document_id + args: + - fieldReferenceValue: __name__ + name: select + - description: testSum + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpr: + - Expr.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + + - description: testCollectionId + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.collection_id: + - Field: __name__ + - "collectionName" + assert_results: + - collectionName: "books" + - description: testCollectionGroup + pipeline: + - CollectionGroup: books + - Select: + - title + - Distinct: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" 
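The odd-looking `ratingMod2` expectation of `0.20000000000000018` in testArithmeticOperations is plain IEEE-754 double arithmetic, which Python reproduces bit-for-bit (assuming the backend uses the same double semantics):

assert 4.2 % 2 == 0.20000000000000018   # 4.2 is not exactly representable
assert 4.2 + 1 == 5.2 and 4.2 * 10 == 42.0 and 4.2 / 2 == 2.1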
+ - title: "Crime and Punishment" + - title: "Dune" + - title: "One Hundred Years of Solitude" + - title: "Pride and Prejudice" + - title: "The Great Gatsby" + - title: "The Handmaid's Tale" + - title: "The Hitchhiker's Guide to the Galaxy" + - title: "The Lord of the Rings" + - title: "To Kill a Mockingbird" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: '' + - stringValue: books + name: collection_group + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: distinct + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + + - description: testDistinct + pipeline: + - Collection: books + - Distinct: + - genre + - Sort: + - Ordering: + - Field: genre + - ASCENDING + assert_results: + - genre: Dystopian + - genre: Fantasy + - genre: Magical Realism + - genre: Modernist + - genre: Psychological Thriller + - genre: Romance + - genre: Science Fiction + - genre: Southern Gothic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: distinct + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: genre + name: sort + + - description: testDocuments + pipeline: + - Documents: + - /books/book1 + - /books/book10 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Dune" + - title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books/book1 + - referenceValue: /books/book10 + name: documents + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + + - description: testDatabase + pipeline: + - Database + - Select: + - title + - Distinct: + - title + - Aggregate: + - AliasedExpr: + - Count: [] + - count + - Select: + - AliasedExpr: + - Conditional: + - Expr.greater_than_or_equal: + - Field: count + - Constant: 10 + - Constant: True + - Constant: False + - result + assert_results: + - result: True + assert_proto: + pipeline: + stages: + - name: database + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: distinct + - args: + - mapValue: + fields: + count: + functionValue: + name: count + - mapValue: {} + name: aggregate + - name: select + args: + - mapValue: + fields: + result: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: count + - integerValue: '10' + name: greater_than_or_equal + - booleanValue: true + - booleanValue: false + name: conditional + - description: testGenericStage + pipeline: + - GenericStage: + - "collection" + - Value: + reference_value: "/books" + - GenericStage: + - "where" + - Expr.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - GenericStage: + - "select" + - Value: + map_value: + fields: + author: + field_reference_value: author + assert_results: + - author: Douglas Adams + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's 
Guide to the Galaxy + name: equal + name: where + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + name: select + - description: testUnnest + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + - tags + - tags_alias + - Select: tags_alias + assert_results: + - tags_alias: comedy + - tags_alias: space + - tags_alias: adventure + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: equal + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + name: select + - description: testUnnestWithOptions + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + field: tags + alias: tags_alias + options: + UnnestOptions: + - index + - Select: + - tags_alias + - index + assert_results: + - tags_alias: comedy + index: 0 + - tags_alias: space + index: 1 + - tags_alias: adventure + index: 2 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: equal + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + options: + index_field: + fieldReferenceValue: index + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + index: + fieldReferenceValue: index + name: select \ No newline at end of file diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml new file mode 100644 index 000000000..203be290d --- /dev/null +++ b/tests/system/pipeline_e2e/logical.yaml @@ -0,0 +1,673 @@ +tests: + - description: whereByMultipleConditions + pipeline: + - Collection: books + - Where: + - And: + - Expr.greater_than: + - Field: rating + - Constant: 4.5 + - Expr.equal: + - Field: genre + - Constant: Science Fiction + assert_results: + - title: Dune + author: Frank Herbert + genre: Science Fiction + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: equal + name: and + name: where + - description: whereByOrCondition + pipeline: + - Collection: books + - Where: + - Or: + - Expr.equal: + - Field: genre + - Constant: Romance + - Expr.equal: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: equal + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: equal + name: 
or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testComparisonOperators + pipeline: + - Collection: books + - Where: + - And: + - Expr.greater_than: + - Field: rating + - Constant: 4.2 + - Expr.less_than_or_equal: + - Field: rating + - Constant: 4.5 + - Expr.not_equal: + - Field: genre + - Constant: Science Fiction + - Select: + - rating + - title + - Sort: + - Ordering: + - title + - ASCENDING + assert_results: + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: less_than_or_equal + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: not_equal + name: and + name: where + - args: + - mapValue: + fields: + rating: + fieldReferenceValue: rating + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testLogicalOperators + pipeline: + - Collection: books + - Where: + - Or: + - And: + - Expr.greater_than: + - Field: rating + - Constant: 4.5 + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Expr.less_than: + - Field: published + - Constant: 1900 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: greater_than + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: equal + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: less_than + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testChecks + pipeline: + - Collection: books + - Where: + - Not: + - Expr.is_nan: + - Field: rating + - Select: + - AliasedExpr: + - Not: + - Expr.is_nan: + - Field: rating + - "ratingIsNotNaN" + - Limit: 1 + assert_results: + - ratingIsNotNaN: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: where + - args: + - mapValue: + fields: + ratingIsNotNaN: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: select + - args: + - integerValue: '1' + name: limit + - description: testIsNotNull + pipeline: + - Collection: books + - Where: + - Expr.is_not_null: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: 
+ stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_null + name: where + - description: testIsNotNaN + pipeline: + - Collection: books + - Where: + - Expr.is_not_nan: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_nan + name: where + - description: testIsAbsent + pipeline: + - Collection: books + - Where: + - Expr.is_absent: + - Field: awards.pulitzer + assert_count: 9 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.pulitzer + name: is_absent + name: where + - description: testIfAbsent + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_absent: + - Field: awards.pulitzer + - Constant: false + - "pulitzer_award" + - title + - Where: + - Expr.equal: + - Field: pulitzer_award + - Constant: true + assert_results: + - pulitzer_award: true + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pulitzer_award: + functionValue: + name: if_absent + args: + - fieldReferenceValue: awards.pulitzer + - booleanValue: false + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: pulitzer_award + - booleanValue: true + name: equal + name: where + - description: testIsError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.is_error: + - Expr.divide: + - Field: rating + - Constant: "string" + - "is_error_result" + - Limit: 1 + assert_results: + - is_error_result: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + is_error_result: + functionValue: + name: is_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - stringValue: "string" + name: select + - args: + - integerValue: '1' + name: limit + - description: testIfError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_error: + - Expr.divide: + - Field: rating + - Field: genre + - Constant: "An error occurred" + - "if_error_result" + - Limit: 1 + assert_results: + - if_error_result: "An error occurred" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + if_error_result: + functionValue: + name: if_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - fieldReferenceValue: genre + - stringValue: "An error occurred" + name: select + - args: + - integerValue: '1' + name: limit + - description: testLogicalMinMax + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpr: + - Expr.logical_maximum: + - Field: rating + - Constant: 4.5 + - "max_rating" + - AliasedExpr: + - Expr.logical_minimum: + - Field: published + - Constant: 1900 + - "min_published" + assert_results: + - max_rating: 4.5 + min_published: 1900 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: 
+ fields: + min_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: minimum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: maximum + name: select + - description: testGreaterThanOrEqual + pipeline: + - Collection: books + - Where: + - Expr.greater_than_or_equal: + - Field: rating + - Constant: 4.6 + - Select: + - title + - rating + - Sort: + - Ordering: + - Field: rating + - ASCENDING + assert_results: + - title: Dune + rating: 4.6 + - title: The Lord of the Rings + rating: 4.7 + - description: testInAndNotIn + pipeline: + - Collection: books + - Where: + - And: + - Expr.equal_any: + - Field: genre + - - Constant: Romance + - Constant: Dystopian + - Expr.not_equal_any: + - Field: author + - - Constant: "George Orwell" + assert_results: + - title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + "arthur c. clarke": true + "booker prize": false + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - Expr.exists: + - Field: awards.pulitzer + - Expr.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testXor + pipeline: + - Collection: books + - Where: + - Xor: + - - Expr.equal: + - Field: genre + - Constant: Romance + - Expr.greater_than: + - Field: published + - Constant: 1980 + - Select: + - title + - genre + - published + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + genre: "Romance" + published: 1813 + - title: "The Handmaid's Tale" + genre: "Dystopian" + published: 1985 + - description: testConditional + pipeline: + - Collection: books + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Constant: "Modern" + - Constant: "Classic" + - "era" + - Sort: + - Ordering: + - Field: title + - ASCENDING + - Limit: 4 + assert_results: + - title: "1984" + era: "Classic" + - title: "Crime and Punishment" + era: "Classic" + - title: "Dune" + era: "Modern" + - title: "One Hundred Years of Solitude" + era: "Modern" + - description: testFieldToFieldComparison + pipeline: + - Collection: books + - Where: + - Expr.greater_than: + - Field: published + - Field: rating + - Select: + - title + assert_count: 10 # All books were published after year 4.7 + - description: testExistsNegative + pipeline: + - Collection: books + - Where: + - Expr.exists: + - Field: non_existent_field + assert_count: 0 + - description: testConditionalWithFields + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "Dune" + - Constant: "1984" + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Field: author + - Field: genre + - "conditional_field" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + conditional_field: "Dystopian" + - title: "Dune" + conditional_field: "Frank Herbert" diff --git a/tests/system/pipeline_e2e/map.yaml b/tests/system/pipeline_e2e/map.yaml new file mode 100644 index 000000000..638fe0798 --- /dev/null +++ 
b/tests/system/pipeline_e2e/map.yaml @@ -0,0 +1,269 @@ +tests: + - description: testMapGet + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: published + - DESCENDING + - Select: + - AliasedExpr: + - Expr.map_get: + - Field: awards + - hugo + - "hugoAward" + - Field: title + - Where: + - Expr.equal: + - Field: hugoAward + - Constant: true + assert_results: + - hugoAward: true + title: The Hitchhiker's Guide to the Galaxy + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: published + name: sort + - args: + - mapValue: + fields: + hugoAward: + functionValue: + args: + - fieldReferenceValue: awards + - stringValue: hugo + name: map_get + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: hugoAward + - booleanValue: true + name: equal + name: where + - description: testMapGetWithField + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - AddFields: + - AliasedExpr: + - Constant: "hugo" + - "award_name" + - Select: + - AliasedExpr: + - Expr.map_get: + - Field: awards + - Field: award_name + - "hugoAward" + - Field: title + assert_results: + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + award_name: + stringValue: "hugo" + name: add_fields + - args: + - mapValue: + fields: + hugoAward: + functionValue: + name: map_get + args: + - fieldReferenceValue: awards + - fieldReferenceValue: award_name + title: + fieldReferenceValue: title + name: select + - description: testMapRemove + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_remove: + - Field: awards + - "nebula" + - "awards_removed" + assert_results: + - awards_removed: + hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_removed: + functionValue: + name: map_remove + args: + - fieldReferenceValue: awards + - stringValue: "nebula" + name: select + - description: testMapMerge + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_merge: + - Field: awards + - Map: + elements: {"new_award": true, "hugo": false} + - Map: + elements: {"another_award": "yes"} + - "awards_merged" + assert_results: + - awards_merged: + hugo: false + nebula: true + new_award: true + another_award: "yes" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_merged: + functionValue: + name: map_merge + args: + - fieldReferenceValue: awards + - functionValue: + name: map + args: + - stringValue: "new_award" + - booleanValue: true + - stringValue: "hugo" + - booleanValue: false + - functionValue: + name: map + args: + - stringValue: "another_award" 
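testMapMerge above expects later arguments to win on key collisions (`hugo` flips to `false`); assuming `map_merge` is defined with last-writer-wins semantics server-side, that is the same behaviour as Python dict unpacking:

awards = {"hugo": True, "nebula": True}  # Dune's awards field
merged = {**awards, **{"new_award": True, "hugo": False}, **{"another_award": "yes"}}
assert merged == {
    "hugo": False,
    "nebula": True,
    "new_award": True,
    "another_award": "yes",
}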
+ - stringValue: "yes" + name: select + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: equal + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testMapMergeLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.map_merge: + - Map: + elements: {"a": "orig", "b": "orig"} + - Map: + elements: {"b": "new", "c": "new"} + - "merged" + assert_results: + - merged: + a: "orig" + b: "new" + c: "new" diff --git a/tests/system/pipeline_e2e/math.yaml b/tests/system/pipeline_e2e/math.yaml new file mode 100644 index 000000000..a5a47d4c0 --- /dev/null +++ b/tests/system/pipeline_e2e/math.yaml @@ -0,0 +1,309 @@ +tests: + - description: testFieldToFieldArithmetic + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.add: + - Field: published + - Field: rating + - "pub_plus_rating" + assert_results: + - pub_plus_rating: 1969.6 + - description: testMathExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.abs: + - Field: rating + - "abs_rating" + - AliasedExpr: + - Expr.ceil: + - Field: rating + - "ceil_rating" + - AliasedExpr: + - Expr.exp: + - Field: rating + - "exp_rating" + - AliasedExpr: + - Expr.floor: + - Field: rating + - "floor_rating" + - AliasedExpr: + - Expr.ln: + - Field: rating + - "ln_rating" + - AliasedExpr: + - Expr.log10: + - Field: rating + - "log_rating_base10" + - AliasedExpr: + - Expr.log: + - Field: rating + - Constant: 2 + - "log_rating_base2" + - AliasedExpr: + - Expr.pow: + - Field: rating + - Constant: 2 + - "pow_rating" + - AliasedExpr: + - Expr.sqrt: + - Field: rating + - "sqrt_rating" + assert_results_approximate: + - abs_rating: 4.2 + ceil_rating: 5.0 + exp_rating: 66.686331 + floor_rating: 4.0 + ln_rating: 1.4350845 + log_rating_base10: 0.623249 + log_rating_base2: 2.0704 + pow_rating: 17.64 + sqrt_rating: 2.049390 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + abs_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: abs + ceil_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ceil + exp_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: exp + floor_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: floor + ln_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ln + log_rating_base10: + functionValue: + args: + - fieldReferenceValue: rating + name: log10 + log_rating_base2: + 
functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: log + pow_rating: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: pow + sqrt_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: sqrt + name: select + - description: testRoundExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "To Kill a Mockingbird" # rating 4.2 + - Constant: "Pride and Prejudice" # rating 4.5 + - Constant: "The Lord of the Rings" # rating 4.7 + - Select: + - title + - AliasedExpr: + - Expr.round: + - Field: rating + - "round_rating" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + round_rating: 5.0 + - title: "The Lord of the Rings" + round_rating: 5.0 + - title: "To Kill a Mockingbird" + round_rating: 4.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - functionValue: + args: + - stringValue: "To Kill a Mockingbird" + - stringValue: "Pride and Prejudice" + - stringValue: "The Lord of the Rings" + name: array + name: equal_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + round_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: round + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArithmeticOperations + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.add: + - Field: rating + - Constant: 1 + - "ratingPlusOne" + - AliasedExpr: + - Expr.subtract: + - Field: published + - Constant: 1900 + - "yearsSince1900" + - AliasedExpr: + - Expr.multiply: + - Field: rating + - Constant: 10 + - "ratingTimesTen" + - AliasedExpr: + - Expr.divide: + - Field: rating + - Constant: 2 + - "ratingDividedByTwo" + - AliasedExpr: + - Expr.multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - AliasedExpr: + - Expr.add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - AliasedExpr: + - Expr.mod: + - Field: rating + - Constant: 2 + - "ratingMod2" + assert_results: + - ratingPlusOne: 5.2 + yearsSince1900: 60 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: + functionValue: + args: + - fieldReferenceValue: rating + - 
integerValue: '2' + name: mod + name: select \ No newline at end of file diff --git a/tests/system/pipeline_e2e/string.yaml b/tests/system/pipeline_e2e/string.yaml new file mode 100644 index 000000000..b1e3a0b64 --- /dev/null +++ b/tests/system/pipeline_e2e/string.yaml @@ -0,0 +1,654 @@ +tests: + - description: testStringConcat + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Select: + - AliasedExpr: + - Expr.string_concat: + - Field: author + - Constant: " - " + - Field: title + - "bookInfo" + - Limit: 1 + assert_results: + - bookInfo: Douglas Adams - The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - mapValue: + fields: + bookInfo: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: ' - ' + - fieldReferenceValue: title + name: string_concat + name: select + - args: + - integerValue: '1' + name: limit + - description: testStartsWith + pipeline: + - Collection: books + - Where: + - Expr.starts_with: + - Field: title + - Constant: The + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: The Great Gatsby + - title: The Handmaid's Tale + - title: The Hitchhiker's Guide to the Galaxy + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + name: starts_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testEndsWith + pipeline: + - Collection: books + - Where: + - Expr.ends_with: + - Field: title + - Constant: y + - Select: + - title + - Sort: + - Ordering: + - Field: title + - DESCENDING + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: y + name: ends_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - description: testConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.concat: + - Field: author + - Constant: ": " + - Field: title + - "author_title" + - AliasedExpr: + - Expr.concat: + - Field: tags + - - Constant: "new_tag" + - "concatenatedTags" + assert_results: + - author_title: "Douglas Adams: The Hitchhiker's Guide to the Galaxy" + concatenatedTags: + - comedy + - space + - adventure + - new_tag + - description: testLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.length: + - Field: title + - "titleLength" + - AliasedExpr: + - Expr.length: + - Field: tags + - "tagsLength" + - AliasedExpr: + - Expr.length: + - Field: awards + - "awardsLength" + 
assert_results: + - titleLength: 36 + tagsLength: 3 + awardsLength: 2 + - description: testCharLength + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.char_length: + - Field: title + - "titleLength" + - title + - Where: + - Expr.greater_than: + - Field: titleLength + - Constant: 20 + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - titleLength: 29 + title: One Hundred Years of Solitude + - titleLength: 36 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 21 + title: The Lord of the Rings + - titleLength: 21 + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + titleLength: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: titleLength + - integerValue: '20' + name: greater_than + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: CharLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.char_length: + - Field: title + - "title_length" + assert_results: + - title_length: 36 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + title_length: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - description: ByteLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpr: + - Expr.byte_length: + - Expr.string_concat: + - Field: title + - Constant: _银河系漫游指南 + - "title_byte_length" + assert_results: + - title_byte_length: 58 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + title_byte_length: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" + name: string_concat + name: byte_length + name: select + - description: testLike + pipeline: + - Collection: books + - Where: + - Expr.like: + - Field: title + - Constant: "%Guide%" + - Select: + - title + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - description: testRegexContains + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - Expr.regex_contains: + - Field: title + - Constant: "(?i)(the|of)" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "(?i)(the|of)" + name: regex_contains + name: where + - description: testRegexMatches + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - Expr.regex_match: + - Field: title + - Constant: ".*(?i)(the|of).*" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: 
/books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: ".*(?i)(the|of).*" + name: regex_match + name: where + - description: testStringContains + pipeline: + - Collection: books + - Where: + - Expr.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: ToLower + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_lower: + - Field: title + - "lower_title" + assert_results: + - lower_title: "the hitchhiker's guide to the galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + lower_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - description: ToUpper + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_upper: + - Field: title + - "upper_title" + assert_results: + - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + upper_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_upper + name: select + - description: Trim + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.trim: + - Expr.string_concat: + - Constant: " " + - Field: title + - Constant: " " + - "trimmed_title" + assert_results: + - trimmed_title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + trimmed_title: + functionValue: + args: + - functionValue: + args: + - stringValue: " " + - fieldReferenceValue: title + - stringValue: " " + name: string_concat + name: trim + name: select + - description: StringReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Jane Austen" + - Select: + - AliasedExpr: + - Expr.string_reverse: + - Field: title + - "reversed_title" + assert_results: + - reversed_title: "ecidujerP dna edirP" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Jane Austen" + name: equal + name: where + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: string_reverse + name: select + - description: Substring + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 4 + - Constant: 11 + - "substring_title" + assert_results: + - substring_title: "Hitchhiker'" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: 
collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '4' + - integerValue: '11' + name: substring + name: select + - description: Substring without length + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Fyodor Dostoevsky" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 10 + - "substring_title" + assert_results: + - substring_title: "Punishment" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Fyodor Dostoevsky" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '10' + name: substring + name: select + - description: Join + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.join: + - Field: tags + - Constant: ", " + - "joined_tags" + assert_results: + - joined_tags: "comedy, space, adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + joined_tags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: ", " + name: join + name: select diff --git a/tests/system/pipeline_e2e/vector.yaml b/tests/system/pipeline_e2e/vector.yaml new file mode 100644 index 000000000..15fc9bcaa --- /dev/null +++ b/tests/system/pipeline_e2e/vector.yaml @@ -0,0 +1,160 @@ +tests: + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpr: + - Expr.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 4 + - description: testFindNearestEuclidean + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: EUCLIDEAN + options: + FindNearestOptions: + limit: 2 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 0.0 + - distance: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: euclidean + options: + limit: + integerValue: '2' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testFindNearestDotProduct + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: DOT_PRODUCT + options: + FindNearestOptions: + limit: 3 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 38.0 + - distance: 17.0 + - distance: 14.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - 
referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: dot_product + options: + limit: + integerValue: '3' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testDotProductWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.dot_product: + - Field: embedding + - Vector: [1.0, 1.0, 1.0] + - "dot_product_result" + assert_results: + - dot_product_result: 6.0 + - description: testEuclideanDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.euclidean_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "euclidean_distance_result" + assert_results: + - euclidean_distance_result: 0.0 + - description: testCosineDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.cosine_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "cosine_distance_result" + assert_results: + - cosine_distance_result: 0.0 diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index 5a93a869e..74b12b7c3 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -18,6 +18,7 @@ FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db") # run all tests against default database, and a named database -# TODO: add enterprise mode when GA (RunQuery not currently supported) TEST_DATABASES = [None, FIRESTORE_OTHER_DB] TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB] +# TODO remove when kokoro fully supports enterprise mode/pipelines +IS_KOKORO_TEST = os.getenv("KOKORO_JOB_NAME") is not None diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 682fe5e23..c7eaa6aff 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -33,26 +33,54 @@ from google.cloud.firestore import Client, AsyncClient -from test__helpers import FIRESTORE_ENTERPRISE_DB +from test__helpers import FIRESTORE_ENTERPRISE_DB, IS_KOKORO_TEST FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") +# TODO: enable kokoro tests when internal test project is whitelisted +pytestmark = pytest.mark.skipif( + condition=IS_KOKORO_TEST, + reason="Pipeline tests are currently not supported by kokoro", +) + test_dir_name = os.path.dirname(__file__) +id_format = ( + lambda x: f"{x.get('file_name', '')}: {x.get('description', '')}" +) # noqa: E731 + -def yaml_loader(field="tests", file_name="pipeline_e2e.yaml"): +def yaml_loader(field="tests", dir_name="pipeline_e2e", attach_file_name=True): """ Helper to load test cases or data from yaml file """ - with open(f"{test_dir_name}/{file_name}") as f: - test_cases = yaml.safe_load(f) - return test_cases[field] + combined_yaml = None + for file_name in os.listdir(f"{test_dir_name}/{dir_name}"): + with open(f"{test_dir_name}/{dir_name}/{file_name}") as f: + new_yaml = yaml.safe_load(f) + assert new_yaml is not None, f"found empty yaml in {file_name}" + extracted = new_yaml.get(field, 
None) + # attach file_name field + if attach_file_name: + if isinstance(extracted, list): + for item in extracted: + item["file_name"] = file_name + elif isinstance(extracted, dict): + extracted["file_name"] = file_name + # aggregate files + if not combined_yaml: + combined_yaml = extracted + elif isinstance(combined_yaml, dict) and extracted: + combined_yaml.update(extracted) + elif isinstance(combined_yaml, list) and extracted: + combined_yaml.extend(extracted) + return combined_yaml @pytest.mark.parametrize( "test_dict", [t for t in yaml_loader() if "assert_proto" in t], - ids=lambda x: f"{x.get('description', '')}", + ids=id_format, ) def test_pipeline_parse_proto(test_dict, client): """ @@ -69,7 +97,7 @@ def test_pipeline_parse_proto(test_dict, client): @pytest.mark.parametrize( "test_dict", [t for t in yaml_loader() if "assert_error" in t], - ids=lambda x: f"{x.get('description', '')}", + ids=id_format, ) def test_pipeline_expected_errors(test_dict, client): """ @@ -94,7 +122,7 @@ def test_pipeline_expected_errors(test_dict, client): or "assert_count" in t or "assert_results_approximate" in t ], - ids=lambda x: f"{x.get('description', '')}", + ids=id_format, ) def test_pipeline_results(test_dict, client): """ @@ -125,7 +153,7 @@ def test_pipeline_results(test_dict, client): @pytest.mark.parametrize( "test_dict", [t for t in yaml_loader() if "assert_error" in t], - ids=lambda x: f"{x.get('description', '')}", + ids=id_format, ) @pytest.mark.asyncio async def test_pipeline_expected_errors_async(test_dict, async_client): @@ -151,7 +179,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): or "assert_count" in t or "assert_results_approximate" in t ], - ids=lambda x: f"{x.get('description', '')}", + ids=id_format, ) @pytest.mark.asyncio async def test_pipeline_results_async(test_dict, async_client): @@ -332,7 +360,7 @@ def client(): Build a client to use for requests """ client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) - data = yaml_loader("data") + data = yaml_loader("data", attach_file_name=False) to_delete = [] try: # setup data diff --git a/tests/system/test_system.py b/tests/system/test_system.py index c2bd93ef8..ce4f64b32 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -45,6 +45,7 @@ ENTERPRISE_MODE_ERROR, TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, + IS_KOKORO_TEST, ) @@ -64,6 +65,12 @@ def _get_credentials_and_project(): @pytest.fixture(scope="session") def database(request): + from test__helpers import FIRESTORE_ENTERPRISE_DB + + # enterprise mode currently does not support RunQuery calls in prod on kokoro test project + # TODO: remove skip when kokoro test project supports full enterprise mode + if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: + pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param @@ -92,6 +99,11 @@ def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + # return early on kokoro. 
Test project doesn't currently support pipelines + # TODO: enable pipeline verification when kokoro test project is whitelisted + if IS_KOKORO_TEST: + pytest.skip("skipping pipeline verification on kokoro") + def _clean_results(results): if isinstance(results, dict): return {k: _clean_results(v) for k, v in results.items()} @@ -2203,7 +2215,7 @@ def on_snapshot(docs, changes, read_time): ) -@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query(client, cleanup, database): db = client collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID) @@ -2224,7 +2236,6 @@ def on_snapshot(docs, changes, read_time): query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada")) query_ran = query_ran_query.stream() assert len(docs) == len([i for i in query_ran]) - verify_pipeline(query_ran_query) on_snapshot.called_count = 0 @@ -2565,7 +2576,7 @@ def test_chunked_and_recursive(client, cleanup, database): assert [doc.id for doc in next(iter)] == page_3_ids -@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_watch_query_order(client, cleanup, database): db = client collection_ref = db.collection("users") @@ -2602,7 +2613,6 @@ def on_snapshot(docs, changes, read_time): ), "expect the sort order to match, born" on_snapshot.called_count += 1 on_snapshot.last_doc_count = len(docs) - verify_pipeline(query_ref) except Exception as e: on_snapshot.failed = e diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index d053cbd7a..ed679402a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -56,6 +56,7 @@ ENTERPRISE_MODE_ERROR, TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, + IS_KOKORO_TEST, ) RETRIES = retries.AsyncRetry( @@ -142,6 +143,12 @@ def _verify_explain_metrics_analyze_false(explain_metrics): @pytest.fixture(scope="session") def database(request): + from test__helpers import FIRESTORE_ENTERPRISE_DB + + # enterprise mode currently does not support RunQuery calls in prod on kokoro test project + # TODO: remove skip when kokoro test project supports full enterprise mode + if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: + pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param @@ -172,6 +179,11 @@ async def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + # return early on kokoro. 
Test project doesn't currently support pipelines + # TODO: enable pipeline verification when kokoro test project is whitelisted + if IS_KOKORO_TEST: + pytest.skip("skipping pipeline verification on kokoro") + def _clean_results(results): if isinstance(results, dict): return {k: _clean_results(v) for k, v in results.items()} diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 9a20fd386..66239f9ea 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -127,12 +127,12 @@ def test_avg_aggregation_no_alias_to_pb(): "in_alias,expected_alias", [("total", "total"), (None, "field_1")] ) def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate from google.cloud.firestore_v1.pipeline_expressions import Count count_aggregation = CountAggregation(alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias assert isinstance(got.expr, Count) assert len(got.expr.params) == 0 @@ -143,14 +143,13 @@ def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias - from google.cloud.firestore_v1.pipeline_expressions import Sum + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate count_aggregation = SumAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias - assert isinstance(got.expr, Sum) + assert got.expr.name == "sum" assert got.expr.params[0].path == expected_path @@ -159,14 +158,13 @@ def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alia [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate count_aggregation = AvgAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias - assert isinstance(got.expr, Avg) + assert got.expr.name == "average" assert got.expr.params[0].path == expected_path @@ -1036,7 +1034,6 @@ def test_aggregation_from_query(): def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Sum client = make_client() parent = client.collection("dee") @@ -1051,7 +1048,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert 
aggregate_stage.accumulators[0].expr.name == "sum" expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -1068,7 +1065,6 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg client = make_client() parent = client.collection("dee") @@ -1083,7 +1079,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert aggregate_stage.accumulators[0].expr.name == "average" expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -1142,7 +1138,6 @@ def test_aggreation_to_pipeline_count_increment(): def test_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -1159,11 +1154,11 @@ def test_aggreation_to_pipeline_complex(): assert isinstance(pipeline.stages[2], Aggregate) aggregate_stage = pipeline.stages[2] assert len(aggregate_stage.accumulators) == 4 - assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].expr.name == "sum" assert aggregate_stage.accumulators[0].alias == "alias" - assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].expr.name == "count" assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].expr.name == "average" assert aggregate_stage.accumulators[2].alias == "field_2" - assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].expr.name == "sum" assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index 701feab5b..f51db482d 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -710,7 +710,6 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false(): def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Sum client = make_async_client() parent = client.collection("dee") @@ -725,7 +724,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert 
aggregate_stage.accumulators[0].expr.name == "sum" expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -742,7 +741,6 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg client = make_async_client() parent = client.collection("dee") @@ -757,7 +755,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert aggregate_stage.accumulators[0].expr.name == "average" expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -816,7 +814,6 @@ def test_aggreation_to_pipeline_count_increment(): def test_async_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count client = make_async_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -833,11 +830,11 @@ def test_async_aggreation_to_pipeline_complex(): assert isinstance(pipeline.stages[2], Aggregate) aggregate_stage = pipeline.stages[2] assert len(aggregate_stage.accumulators) == 4 - assert isinstance(aggregate_stage.accumulators[0].expr, Sum) + assert aggregate_stage.accumulators[0].expr.name == "sum" assert aggregate_stage.accumulators[0].alias == "alias" - assert isinstance(aggregate_stage.accumulators[1].expr, Count) + assert aggregate_stage.accumulators[1].expr.name == "count" assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert aggregate_stage.accumulators[2].expr.name == "average" assert aggregate_stage.accumulators[2].alias == "field_2" - assert isinstance(aggregate_stage.accumulators[3].expr, Sum) + assert aggregate_stage.accumulators[3].expr.name == "sum" assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index b3ed83337..a11a2951b 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -386,10 +386,10 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), ("where", (Field.of("n").exists(),), stages.Where), - ("find_nearest", ("name", [0.1], 0), stages.FindNearest), + ("find_nearest", ("name", [0.1], "cosine"), stages.FindNearest), ( "find_nearest", - ("name", [0.1], 0, stages.FindNearestOptions(10)), + ("name", [0.1], "cosine", stages.FindNearestOptions(10)), stages.FindNearest, ), ("sort", (Field.of("n").descending(),), stages.Sort), diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 9bb3e61f8..7efa0dacf 100644 --- a/tests/unit/v1/test_base_query.py +++ 
b/tests/unit/v1/test_base_query.py
@@ -2040,9 +2040,7 @@ def test__query_pipeline_composite_filter():
     client = make_client()
     in_filter = FieldFilter("field_a", "==", "value_a")
     query = client.collection("my_col").where(filter=in_filter)
-    with mock.patch.object(
-        expr.FilterCondition, "_from_query_filter_pb"
-    ) as convert_mock:
+    with mock.patch.object(expr.BooleanExpr, "_from_query_filter_pb") as convert_mock:
         pipeline = query.pipeline()
     convert_mock.assert_called_once_with(in_filter._to_pb(), client)
     assert len(pipeline.stages) == 2
@@ -2080,15 +2078,13 @@ def test__query_pipeline_order_exists_multiple():
     assert isinstance(where_stage.condition, expr.And)
     assert len(where_stage.condition.params) == 2
     operands = [p for p in where_stage.condition.params]
-    assert isinstance(operands[0], expr.Exists)
+    assert operands[0].name == "exists"
     assert operands[0].params[0].path == "field_a"
-    assert isinstance(operands[1], expr.Exists)
+    assert operands[1].name == "exists"
     assert operands[1].params[0].path == "field_b"
 
 
 def test__query_pipeline_order_exists_single():
-    from google.cloud.firestore_v1 import pipeline_expressions as expr
-
     client = make_client()
     query_single = client.collection("my_col").order_by("field_c")
     pipeline_single = query_single.pipeline()
@@ -2098,7 +2094,7 @@ def test__query_pipeline_order_exists_single():
     assert len(pipeline_single.stages) == 3
     where_stage_single = pipeline_single.stages[1]
     assert isinstance(where_stage_single, stages.Where)
-    assert isinstance(where_stage_single.condition, expr.Exists)
+    assert where_stage_single.condition.name == "exists"
     assert where_stage_single.condition.params[0].path == "field_c"
 
 
diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py
index f90279e00..161eef1cc 100644
--- a/tests/unit/v1/test_pipeline.py
+++ b/tests/unit/v1/test_pipeline.py
@@ -363,10 +363,10 @@ def test_pipeline_execute_stream_equivalence_mocked():
         ("select", ("name",), stages.Select),
         ("select", (Field.of("n"),), stages.Select),
         ("where", (Field.of("n").exists(),), stages.Where),
-        ("find_nearest", ("name", [0.1], 0), stages.FindNearest),
+        ("find_nearest", ("name", [0.1], "cosine"), stages.FindNearest),
         (
             "find_nearest",
-            ("name", [0.1], 0, stages.FindNearestOptions(10)),
+            ("name", [0.1], "cosine", stages.FindNearestOptions(10)),
             stages.FindNearest,
         ),
         ("sort", (Field.of("n").descending(),), stages.Sort),
diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py
index aec721e7d..522b51c84 100644
--- a/tests/unit/v1/test_pipeline_expressions.py
+++ b/tests/unit/v1/test_pipeline_expressions.py
@@ -645,6 +645,14 @@ def test_w_exprs(self):
         )
 
 
+class TestFunction:
+    def test_equals(self):
+        assert expr.Function.sqrt("1") == expr.Function.sqrt("1")
+        assert expr.Function.sqrt("1") != expr.Function.sqrt("2")
+        assert expr.Function.sqrt("1") != expr.Function.sum("1")
+        assert expr.Function.sqrt("1") != object()
+
+
 class TestExpressionMethods:
     """
     contains test methods for each Expr method
@@ -692,6 +700,22 @@ def __repr__(self):
         arg = MockExpr(name)
         return arg
 
+    def test_expression_wrong_first_type(self):
+        """The first argument should always be an expression or field name"""
+        expected_message = "must be called on an Expression or a string representing a field. got <class 'int'>."
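+        # NOTE: the raised message embeds the repr of the offending argument's
+        # type; type(5) and type(9) in the calls below both render as
+        # "<class 'int'>" (see expose_as_static.static_func in
+        # pipeline_expressions.py).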
+ with pytest.raises(TypeError) as e1: + Expr.logical_minimum(5, 1) + assert str(e1.value) == f"'logical_minimum' {expected_message}" + with pytest.raises(TypeError) as e2: + Expr.sqrt(9) + assert str(e2.value) == f"'sqrt' {expected_message}" + + def test_expression_w_string(self): + """should be able to use string for first argument. Should be interpreted as Field""" + instance = Expr.logical_minimum("first", "second") + assert isinstance(instance.params[0], Field) + assert instance.params[0].path == "first" + def test_and(self): arg1 = self._make_arg() arg2 = self._make_arg() diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index fadea7e12..1d2ff8760 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -274,12 +274,11 @@ def test_to_pb(self): instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2") result = instance._to_pb() assert result.name == "documents" - assert len(result.args) == 1 + assert len(result.args) == 2 assert ( - result.args[0].array_value.values[0].string_value - == "/projects/p/databases/d/documents/c/doc1" + result.args[0].reference_value == "/projects/p/databases/d/documents/c/doc1" ) - assert result.args[0].array_value.values[1].string_value == "/c/doc2" + assert result.args[1].reference_value == "/c/doc2" assert len(result.options) == 0 @@ -461,10 +460,22 @@ def _make_one(self, *args, **kwargs): ), ], ) - def test_ctor(self, input_args, expected_params): + def test_ctor_with_params(self, input_args, expected_params): instance = self._make_one(*input_args) assert instance.params == expected_params + def test_ctor_with_options(self): + options = {"index_field": Field.of("index")} + field = Field.of("field") + alias = Field.of("alias") + standard_unnest = stages.Unnest( + field, alias, options=stages.UnnestOptions(**options) + ) + generic_unnest = stages.GenericStage("unnest", field, alias, options=options) + assert standard_unnest._pb_args() == generic_unnest._pb_args() + assert standard_unnest._pb_options() == generic_unnest._pb_options() + assert standard_unnest._to_pb() == generic_unnest._to_pb() + @pytest.mark.parametrize( "input_args,expected", [ @@ -709,7 +720,8 @@ def _make_one_options(self, *args, **kwargs): def test_ctor_options(self): index_field_val = "my_index" instance = self._make_one_options(index_field=index_field_val) - assert instance.index_field == index_field_val + assert isinstance(instance.index_field, Field) + assert instance.index_field.path == index_field_val def test_repr(self): instance = self._make_one_options(index_field="my_idx") @@ -781,7 +793,7 @@ def test_to_pb_full(self): assert result.args[1].field_reference_value == alias_str assert len(result.options) == 1 - assert result.options["index_field"].string_value == "item_index" + assert result.options["index_field"].field_reference_value == "item_index" class TestWhere: From b5b0bd75cc6f46d3a9f959d73c718da7099ebbbf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 30 Oct 2025 16:23:34 -0700 Subject: [PATCH 08/27] chore: Pipeline queries cleanup (#1118) --- google/cloud/firestore_v1/async_pipeline.py | 2 +- google/cloud/firestore_v1/base_aggregation.py | 4 +- google/cloud/firestore_v1/base_pipeline.py | 51 +- google/cloud/firestore_v1/base_query.py | 2 +- google/cloud/firestore_v1/pipeline.py | 2 +- .../firestore_v1/pipeline_expressions.py | 499 +++++++++--------- google/cloud/firestore_v1/pipeline_source.py | 2 +- ..._pipeline_stages.py => pipeline_stages.py} | 29 +- 
tests/system/pipeline_e2e/aggregates.yaml | 67 ++- tests/system/pipeline_e2e/array.yaml | 46 +- tests/system/pipeline_e2e/date_and_time.yaml | 42 +- tests/system/pipeline_e2e/general.yaml | 168 +----- tests/system/pipeline_e2e/logical.yaml | 88 +-- tests/system/pipeline_e2e/map.yaml | 32 +- tests/system/pipeline_e2e/math.yaml | 84 +-- tests/system/pipeline_e2e/string.yaml | 104 ++-- tests/system/pipeline_e2e/vector.yaml | 22 +- tests/system/test_pipeline_acceptance.py | 2 +- tests/unit/v1/test_aggregation.py | 20 +- tests/unit/v1/test_async_aggregation.py | 8 +- tests/unit/v1/test_async_collection.py | 2 +- tests/unit/v1/test_async_pipeline.py | 30 +- tests/unit/v1/test_base_query.py | 6 +- tests/unit/v1/test_collection.py | 2 +- tests/unit/v1/test_pipeline.py | 30 +- tests/unit/v1/test_pipeline_expressions.py | 315 +++++------ tests/unit/v1/test_pipeline_source.py | 2 +- tests/unit/v1/test_pipeline_stages.py | 12 +- 28 files changed, 767 insertions(+), 906 deletions(-) rename google/cloud/firestore_v1/{_pipeline_stages.py => pipeline_stages.py} (95%) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 471c33093..9fe0c8756 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import AsyncIterable, TYPE_CHECKING -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 3dd7a453e..d8d7cc6b4 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -36,7 +36,7 @@ ) from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction from google.cloud.firestore_v1.pipeline_expressions import Count -from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints @@ -86,7 +86,7 @@ def _to_protobuf(self): @abc.abstractmethod def _to_pipeline_expr( self, autoindexer: Iterable[int] - ) -> AliasedExpr[AggregateFunction]: + ) -> AliasedExpression[AggregateFunction]: """ Convert this instance to a pipeline expression for use with pipeline.aggregate() diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 01f48ee78..63fee19fa 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import Iterable, Sequence, TYPE_CHECKING -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) @@ -23,10 +23,11 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( - AliasedAggregate, - Expr, + AggregateFunction, + AliasedExpression, + Expression, Field, - BooleanExpr, + BooleanExpression, Selectable, ) from google.cloud.firestore_v1 
import _helpers @@ -146,7 +147,7 @@ def add_fields(self, *fields: Selectable) -> "_BasePipeline": The added fields are defined using `Selectable` expressions, which can be: - `Field`: References an existing document field. - `Function`: Performs a calculation using functions like `add`, - `multiply` with assigned aliases using `Expr.as_()`. + `multiply` with assigned aliases using `Expression.as_()`. Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add @@ -193,7 +194,7 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": The selected fields are defined using `Selectable` expressions or field names: - `Field`: References an existing document field. - `Function`: Represents the result of a function with an assigned alias - name using `Expr.as_()`. + name using `Expression.as_()`. - `str`: The name of an existing field. If no selections are provided, the output of this stage is empty. Use @@ -219,14 +220,14 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": """ return self._append(stages.Select(*selections)) - def where(self, condition: BooleanExpr) -> "_BasePipeline": + def where(self, condition: BooleanExpression) -> "_BasePipeline": """ Filters the documents from previous stages to only include those matching - the specified `BooleanExpr`. + the specified `BooleanExpression`. This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using - implementations of `BooleanExpr`, typically including but not limited to: + implementations of `BooleanExpression`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. - logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. @@ -251,7 +252,7 @@ def where(self, condition: BooleanExpr) -> "_BasePipeline": Args: - condition: The `BooleanExpr` to apply. + condition: The `BooleanExpression` to apply. Returns: A new Pipeline object with this stage appended to the stage list @@ -260,7 +261,7 @@ def where(self, condition: BooleanExpr) -> "_BasePipeline": def find_nearest( self, - field: str | Expr, + field: str | Expression, vector: Sequence[float] | "Vector", distance_measure: "DistanceMeasure", options: stages.FindNearestOptions | None = None, @@ -297,7 +298,7 @@ def find_nearest( ... ) Args: - field: The name of the field (str) or an expression (`Expr`) that + field: The name of the field (str) or an expression (`Expression`) that evaluates to the vector data. This field should store vector values. vector: The target vector (sequence of floats or `Vector` object) to compare against. @@ -457,28 +458,29 @@ def unnest( """ return self._append(stages.Unnest(field, alias, options)) - def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline": + def raw_stage(self, name: str, *params: Expression) -> "_BasePipeline": """ - Adds a generic, named stage to the pipeline with specified parameters. + Adds a stage to the pipeline by specifying the stage name as an argument. This does not offer any + type safety on the stage params and requires the caller to know the order (and optionally names) + of parameters accepted by the stage. - This method provides a flexible way to extend the pipeline's functionality - by adding custom stages. Each generic stage is defined by a unique `name` - and a set of `params` that control its behavior. 
+ This class provides a way to call stages that are supported by the Firestore backend but that + are not implemented in the SDK version being used. Example: >>> # Assume we don't have a built-in "where" stage >>> pipeline = client.pipeline().collection("books") - >>> pipeline = pipeline.generic_stage("where", [Field.of("published").lt(900)]) + >>> pipeline = pipeline.raw_stage("where", Field.of("published").lt(900)) >>> pipeline = pipeline.select("title", "author") Args: - name: The name of the generic stage. - *params: A sequence of `Expr` objects representing the parameters for the stage. + name: The name of the stage. + *params: A sequence of `Expression` objects representing the parameters for the stage. Returns: A new Pipeline object with this stage appended to the stage list """ - return self._append(stages.GenericStage(name, *params)) + return self._append(stages.RawStage(name, *params)) def offset(self, offset: int) -> "_BasePipeline": """ @@ -530,7 +532,7 @@ def limit(self, limit: int) -> "_BasePipeline": def aggregate( self, - *accumulators: AliasedAggregate, + *accumulators: AliasedExpression[AggregateFunction], groups: Sequence[str | Selectable] = (), ) -> "_BasePipeline": """ @@ -546,7 +548,6 @@ def aggregate( - **Groups:** Optionally specify fields (by name or `Selectable`) to group the documents by. Aggregations are then performed within each distinct group. If no groups are provided, the aggregation is performed over the entire input. - Example: >>> from google.cloud.firestore_v1.pipeline_expressions import Field >>> pipeline = client.pipeline().collection("books") @@ -568,8 +569,8 @@ def aggregate( Args: - *accumulators: One or more `AliasedAggregate` expressions defining - the aggregations to perform and their output names. + *accumulators: One or more expressions defining the aggregations to perform and their + corresponding output names. groups: An optional sequence of field names (str) or `Selectable` expressions to group by before aggregating. diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 797572b1b..0f4347e5f 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1151,7 +1151,7 @@ def pipeline(self): # Filters for filter_ in self._field_filters: ppl = ppl.where( - pipeline_expressions.BooleanExpr._from_query_filter_pb( + pipeline_expressions.BooleanExpression._from_query_filter_pb( filter_, self._client ) ) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 9f568f925..f578e00b6 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import Iterable, TYPE_CHECKING -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 30f3de995..780fe8e8e 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -56,12 +56,12 @@ def __init__(self, expr, order_dir: Direction | str = Direction.ASCENDING): Initializes an Ordering instance Args: - expr (Expr | str): The expression or field path string to sort by. 
+            expr (Expression | str): The expression or field path string to sort by.
                 If a string is provided, it's treated as a field path.
             order_dir (Direction | str): The direction to sort in.
                 Defaults to ascending
         """
-        self.expr = expr if isinstance(expr, Expr) else Field.of(expr)
+        self.expr = expr if isinstance(expr, Expression) else Field.of(expr)
         self.order_dir = (
             Ordering.Direction[order_dir.upper()]
             if isinstance(order_dir, str)
@@ -86,11 +86,11 @@ def _to_pb(self) -> Value:
         )
 
 
-class Expr(ABC):
+class Expression(ABC):
     """Represents an expression that can be evaluated to a value within the
     execution of a pipeline.
 
     Expressions are the building blocks for creating complex queries and
     transformations in Firestore pipelines. They can represent:
 
     - **Field references:** Access values from document fields.
     - **Literals:** Represent constant values (strings, numbers, booleans).
     - **Function calls:** Apply functions to one or more expressions.
     - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set
       of documents.
 
-    The `Expr` class provides a fluent API for building expressions. You can chain
+    The `Expression` class provides a fluent API for building expressions. You can chain
     together method calls to create complex expressions.
     """
 
@@ -110,9 +110,11 @@ def _to_pb(self) -> Value:
         raise NotImplementedError
 
     @staticmethod
-    def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False) -> "Expr":
-        """Convert arbitrary object to an Expr."""
-        if isinstance(o, Expr):
+    def _cast_to_expr_or_convert_to_constant(
+        o: Any, include_vector=False
+    ) -> "Expression":
+        """Convert arbitrary object to an Expression."""
+        if isinstance(o, Expression):
             return o
         if isinstance(o, dict):
             return Map(o)
@@ -139,12 +141,14 @@ def __init__(self, instance_func):
         self.instance_func = instance_func
 
     def static_func(self, first_arg, *other_args, **kwargs):
-        if not isinstance(first_arg, (Expr, str)):
+        if not isinstance(first_arg, (Expression, str)):
             raise TypeError(
                 f"'{self.instance_func.__name__}' must be called on an Expression or a string representing a field. got {type(first_arg)}."
             )
         first_expr = (
-            Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg
+            Field.of(first_arg)
+            if not isinstance(first_arg, Expression)
+            else first_arg
         )
         return self.instance_func(first_expr, *other_args, **kwargs)
 
@@ -155,7 +159,7 @@ def __get__(self, instance, owner):
         return self.instance_func.__get__(instance, owner)
 
     @expose_as_static
-    def add(self, other: Expr | float) -> "Expr":
+    def add(self, other: Expression | float) -> "Expression":
         """Creates an expression that adds this expression to another expression or constant.
 
         Example:
             >>> # Add the value of the 'quantity' field and the 'reserve' field.
             >>> Field.of("quantity").add(Field.of("reserve"))
 
         Args:
             other: The expression or constant value to add to this expression.
 
         Returns:
-            A new `Expr` representing the addition operation.
+            A new `Expression` representing the addition operation.
         """
         return Function("add", [self, self._cast_to_expr_or_convert_to_constant(other)])
 
     @expose_as_static
-    def subtract(self, other: Expr | float) -> "Expr":
+    def subtract(self, other: Expression | float) -> "Expression":
         """Creates an expression that subtracts another expression or constant from this expression.
 
         Example:
             >>> # Subtract the 'discount' field from the 'price' field
             >>> Field.of("price").subtract(Field.of("discount"))
 
         Args:
             other: The expression or constant value to subtract from this expression.
 
         Returns:
-            A new `Expr` representing the subtraction operation.
+ A new `Expression` representing the subtraction operation. """ return Function( "subtract", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def multiply(self, other: Expr | float) -> "Expr": + def multiply(self, other: Expression | float) -> "Expression": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -206,14 +210,14 @@ def multiply(self, other: Expr | float) -> "Expr": other: The expression or constant value to multiply by. Returns: - A new `Expr` representing the multiplication operation. + A new `Expression` representing the multiplication operation. """ return Function( "multiply", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def divide(self, other: Expr | float) -> "Expr": + def divide(self, other: Expression | float) -> "Expression": """Creates an expression that divides this expression by another expression or constant. Example: @@ -226,14 +230,14 @@ def divide(self, other: Expr | float) -> "Expr": other: The expression or constant value to divide by. Returns: - A new `Expr` representing the division operation. + A new `Expression` representing the division operation. """ return Function( "divide", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def mod(self, other: Expr | float) -> "Expr": + def mod(self, other: Expression | float) -> "Expression": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -246,12 +250,12 @@ def mod(self, other: Expr | float) -> "Expr": other: The divisor expression or constant. Returns: - A new `Expr` representing the modulo operation. + A new `Expression` representing the modulo operation. """ return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) @expose_as_static - def abs(self) -> "Expr": + def abs(self) -> "Expression": """Creates an expression that calculates the absolute value of this expression. Example: @@ -259,12 +263,12 @@ def abs(self) -> "Expr": >>> Field.of("change").abs() Returns: - A new `Expr` representing the absolute value. + A new `Expression` representing the absolute value. """ return Function("abs", [self]) @expose_as_static - def ceil(self) -> "Expr": + def ceil(self) -> "Expression": """Creates an expression that calculates the ceiling of this expression. Example: @@ -272,12 +276,12 @@ def ceil(self) -> "Expr": >>> Field.of("value").ceil() Returns: - A new `Expr` representing the ceiling value. + A new `Expression` representing the ceiling value. """ return Function("ceil", [self]) @expose_as_static - def exp(self) -> "Expr": + def exp(self) -> "Expression": """Creates an expression that computes e to the power of this expression. Example: @@ -285,12 +289,12 @@ def exp(self) -> "Expr": >>> Field.of("value").exp() Returns: - A new `Expr` representing the exponential value. + A new `Expression` representing the exponential value. """ return Function("exp", [self]) @expose_as_static - def floor(self) -> "Expr": + def floor(self) -> "Expression": """Creates an expression that calculates the floor of this expression. Example: @@ -298,12 +302,12 @@ def floor(self) -> "Expr": >>> Field.of("value").floor() Returns: - A new `Expr` representing the floor value. + A new `Expression` representing the floor value. """ return Function("floor", [self]) @expose_as_static - def ln(self) -> "Expr": + def ln(self) -> "Expression": """Creates an expression that calculates the natural logarithm of this expression. 
Example: @@ -311,12 +315,12 @@ def ln(self) -> "Expr": >>> Field.of("value").ln() Returns: - A new `Expr` representing the natural logarithm. + A new `Expression` representing the natural logarithm. """ return Function("ln", [self]) @expose_as_static - def log(self, base: Expr | float) -> "Expr": + def log(self, base: Expression | float) -> "Expression": """Creates an expression that calculates the logarithm of this expression with a given base. Example: @@ -329,24 +333,24 @@ def log(self, base: Expr | float) -> "Expr": base: The base of the logarithm. Returns: - A new `Expr` representing the logarithm. + A new `Expression` representing the logarithm. """ return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)]) @expose_as_static - def log10(self) -> "Expr": + def log10(self) -> "Expression": """Creates an expression that calculates the base 10 logarithm of this expression. Example: >>> Field.of("value").log10() Returns: - A new `Expr` representing the logarithm. + A new `Expression` representing the logarithm. """ return Function("log10", [self]) @expose_as_static - def pow(self, exponent: Expr | float) -> "Expr": + def pow(self, exponent: Expression | float) -> "Expression": """Creates an expression that calculates this expression raised to the power of the exponent. Example: @@ -359,14 +363,14 @@ def pow(self, exponent: Expr | float) -> "Expr": exponent: The exponent. Returns: - A new `Expr` representing the power operation. + A new `Expression` representing the power operation. """ return Function( "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] ) @expose_as_static - def round(self) -> "Expr": + def round(self) -> "Expression": """Creates an expression that rounds this expression to the nearest integer. Example: @@ -374,12 +378,12 @@ def round(self) -> "Expr": >>> Field.of("value").round() Returns: - A new `Expr` representing the rounded value. + A new `Expression` representing the rounded value. """ return Function("round", [self]) @expose_as_static - def sqrt(self) -> "Expr": + def sqrt(self) -> "Expression": """Creates an expression that calculates the square root of this expression. Example: @@ -387,12 +391,12 @@ def sqrt(self) -> "Expr": >>> Field.of("area").sqrt() Returns: - A new `Expr` representing the square root. + A new `Expression` representing the square root. """ return Function("sqrt", [self]) @expose_as_static - def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": + def logical_maximum(self, other: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -409,7 +413,7 @@ def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical maximum operation. + A new `Expression` representing the logical maximum operation. """ return Function( "maximum", @@ -418,7 +422,7 @@ def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": ) @expose_as_static - def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": + def logical_minimum(self, other: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. 
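        As an illustrative sketch (``score`` being an assumed numeric field name),
        capping a value at 100 can be written as:

            >>> Field.of("score").logical_minimum(100)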
@@ -435,7 +439,7 @@ def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical minimum operation. + A new `Expression` representing the logical minimum operation. """ return Function( "minimum", @@ -444,7 +448,7 @@ def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": ) @expose_as_static - def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def equal(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": """Creates an expression that checks if this expression is equal to another expression or constant value. @@ -458,14 +462,14 @@ def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for equality. Returns: - A new `Expr` representing the equality comparison. + A new `Expression` representing the equality comparison. """ - return BooleanExpr( + return BooleanExpression( "equal", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def not_equal(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": """Creates an expression that checks if this expression is not equal to another expression or constant value. @@ -479,14 +483,14 @@ def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for inequality. Returns: - A new `Expr` representing the inequality comparison. + A new `Expression` representing the inequality comparison. """ - return BooleanExpr( + return BooleanExpression( "not_equal", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": """Creates an expression that checks if this expression is greater than another expression or constant value. @@ -500,14 +504,16 @@ def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for greater than. Returns: - A new `Expr` representing the greater than comparison. + A new `Expression` representing the greater than comparison. """ - return BooleanExpr( + return BooleanExpression( "greater_than", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than_or_equal( + self, other: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. @@ -521,15 +527,15 @@ def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for greater than or equal to. Returns: - A new `Expr` representing the greater than or equal to comparison. + A new `Expression` representing the greater than or equal to comparison. """ - return BooleanExpr( + return BooleanExpression( "greater_than_or_equal", [self, self._cast_to_expr_or_convert_to_constant(other)], ) @expose_as_static - def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than(self, other: Expression | CONSTANT_TYPE) -> "BooleanExpression": """Creates an expression that checks if this expression is less than another expression or constant value. 
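        Comparisons compose with the logical helpers defined later in this module;
        for instance, a half-open range check (``price`` being an assumed field name):

            >>> And(Field.of("price").greater_than_or_equal(10), Field.of("price").less_than(100))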
@@ -543,14 +549,16 @@ def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for less than. Returns: - A new `Expr` representing the less than comparison. + A new `Expression` representing the less than comparison. """ - return BooleanExpr( + return BooleanExpression( "less_than", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @expose_as_static - def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than_or_equal( + self, other: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. @@ -564,17 +572,17 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": other: The expression or constant value to compare for less than or equal to. Returns: - A new `Expr` representing the less than or equal to comparison. + A new `Expression` representing the less than or equal to comparison. """ - return BooleanExpr( + return BooleanExpression( "less_than_or_equal", [self, self._cast_to_expr_or_convert_to_constant(other)], ) @expose_as_static def equal_any( - self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr - ) -> "BooleanExpr": + self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression + ) -> "BooleanExpression": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -586,9 +594,9 @@ def equal_any( array: The values or expressions to check against. Returns: - A new `Expr` representing the 'IN' comparison. + A new `Expression` representing the 'IN' comparison. """ - return BooleanExpr( + return BooleanExpression( "equal_any", [ self, @@ -598,8 +606,8 @@ def equal_any( @expose_as_static def not_equal_any( - self, array: Array | list[Expr | CONSTANT_TYPE] | Expr - ) -> "BooleanExpr": + self, array: Array | list[Expression | CONSTANT_TYPE] | Expression + ) -> "BooleanExpression": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -611,9 +619,9 @@ def not_equal_any( array: The values or expressions to check against. Returns: - A new `Expr` representing the 'NOT IN' comparison. + A new `Expression` representing the 'NOT IN' comparison. """ - return BooleanExpr( + return BooleanExpression( "not_equal_any", [ self, @@ -622,7 +630,9 @@ def not_equal_any( ) @expose_as_static - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def array_contains( + self, element: Expression | CONSTANT_TYPE + ) -> "BooleanExpression": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -635,17 +645,17 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": element: The element (expression or constant) to search for in the array. Returns: - A new `Expr` representing the 'array_contains' comparison. + A new `Expression` representing the 'array_contains' comparison. """ - return BooleanExpr( + return BooleanExpression( "array_contains", [self, self._cast_to_expr_or_convert_to_constant(element)] ) @expose_as_static def array_contains_all( self, - elements: Array | list[Expr | CONSTANT_TYPE] | Expr, - ) -> "BooleanExpr": + elements: Array | list[Expression | CONSTANT_TYPE] | Expression, + ) -> "BooleanExpression": """Creates an expression that checks if an array contains all the specified elements. 
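        The elements may mix constants and expressions; a minimal sketch (assuming a
        ``tags`` array field and a ``requiredTag`` field):

            >>> Field.of("tags").array_contains_all(["magic", Field.of("requiredTag")])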
Example: @@ -658,9 +668,9 @@ def array_contains_all( elements: The list of elements (expressions or constants) to check for in the array. Returns: - A new `Expr` representing the 'array_contains_all' comparison. + A new `Expression` representing the 'array_contains_all' comparison. """ - return BooleanExpr( + return BooleanExpression( "array_contains_all", [ self, @@ -671,8 +681,8 @@ def array_contains_all( @expose_as_static def array_contains_any( self, - elements: Array | list[Expr | CONSTANT_TYPE] | Expr, - ) -> "BooleanExpr": + elements: Array | list[Expression | CONSTANT_TYPE] | Expression, + ) -> "BooleanExpression": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -686,9 +696,9 @@ def array_contains_any( elements: The list of elements (expressions or constants) to check for in the array. Returns: - A new `Expr` representing the 'array_contains_any' comparison. + A new `Expression` representing the 'array_contains_any' comparison. """ - return BooleanExpr( + return BooleanExpression( "array_contains_any", [ self, @@ -697,7 +707,7 @@ def array_contains_any( ) @expose_as_static - def array_length(self) -> "Expr": + def array_length(self) -> "Expression": """Creates an expression that calculates the length of an array. Example: @@ -705,12 +715,12 @@ def array_length(self) -> "Expr": >>> Field.of("cart").array_length() Returns: - A new `Expr` representing the length of the array. + A new `Expression` representing the length of the array. """ return Function("array_length", [self]) @expose_as_static - def array_reverse(self) -> "Expr": + def array_reverse(self) -> "Expression": """Creates an expression that returns the reversed content of an array. Example: @@ -718,14 +728,14 @@ def array_reverse(self) -> "Expr": >>> Field.of("preferences").array_reverse() Returns: - A new `Expr` representing the reversed array. + A new `Expression` representing the reversed array. """ return Function("array_reverse", [self]) @expose_as_static def array_concat( - self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr - ) -> "Expr": + self, *other_arrays: Array | list[Expression | CONSTANT_TYPE] | Expression + ) -> "Expression": """Creates an expression that concatenates an array expression with another array. Example: @@ -736,7 +746,7 @@ def array_concat( array: The list of constants or expressions to concat with. Returns: - A new `Expr` representing the concatenated array. + A new `Expression` representing the concatenated array. """ return Function( "array_concat", @@ -745,14 +755,14 @@ def array_concat( ) @expose_as_static - def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr": + def concat(self, *others: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that concatenates expressions together Args: *others: The expressions to concatenate. Returns: - A new `Expr` representing the concatenated value. + A new `Expression` representing the concatenated value. """ return Function( "concat", @@ -760,7 +770,7 @@ def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr": ) @expose_as_static - def length(self) -> "Expr": + def length(self) -> "Expression": """ Creates an expression that calculates the length of the expression if it is a string, array, map, or blob. @@ -769,12 +779,12 @@ def length(self) -> "Expr": >>> Field.of("name").length() Returns: - A new `Expr` representing the length of the expression. + A new `Expression` representing the length of the expression. 
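        Because the function accepts strings, arrays, maps, and blobs, the same call
        works unchanged on an array field; for example (``tags`` being an assumed
        array field):

            >>> Field.of("tags").length()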
""" return Function("length", [self]) @expose_as_static - def is_absent(self) -> "BooleanExpr": + def is_absent(self) -> "BooleanExpression": """Creates an expression that returns true if a value is absent. Otherwise, returns false even if the value is null. @@ -783,12 +793,12 @@ def is_absent(self) -> "BooleanExpr": >>> Field.of("email").is_absent() Returns: - A new `BooleanExpression` representing the isAbsent operation. + A new `BooleanExpressionession` representing the isAbsent operation. """ - return BooleanExpr("is_absent", [self]) + return BooleanExpression("is_absent", [self]) @expose_as_static - def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": + def if_absent(self, default_value: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns a default value if an expression evaluates to an absent value. Example: @@ -799,7 +809,7 @@ def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": default_value: The expression or constant value to return if this expression is absent. Returns: - A new `Expr` representing the ifAbsent operation. + A new `Expression` representing the ifAbsent operation. """ return Function( "if_absent", @@ -807,7 +817,7 @@ def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": ) @expose_as_static - def is_nan(self) -> "BooleanExpr": + def is_nan(self) -> "BooleanExpression": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -815,12 +825,12 @@ def is_nan(self) -> "BooleanExpr": >>> Field.of("value").divide(0).is_nan() Returns: - A new `Expr` representing the 'isNaN' check. + A new `Expression` representing the 'isNaN' check. """ - return BooleanExpr("is_nan", [self]) + return BooleanExpression("is_nan", [self]) @expose_as_static - def is_not_nan(self) -> "BooleanExpr": + def is_not_nan(self) -> "BooleanExpression": """Creates an expression that checks if this expression evaluates to a non-'NaN' (Not a Number) value. Example: @@ -828,33 +838,33 @@ def is_not_nan(self) -> "BooleanExpr": >>> Field.of("value").divide(1).is_not_nan() Returns: - A new `Expr` representing the 'is not NaN' check. + A new `Expression` representing the 'is not NaN' check. """ - return BooleanExpr("is_not_nan", [self]) + return BooleanExpression("is_not_nan", [self]) @expose_as_static - def is_null(self) -> "BooleanExpr": + def is_null(self) -> "BooleanExpression": """Creates an expression that checks if the value of a field is 'Null'. Example: >>> Field.of("value").is_null() Returns: - A new `Expr` representing the 'isNull' check. + A new `Expression` representing the 'isNull' check. """ - return BooleanExpr("is_null", [self]) + return BooleanExpression("is_null", [self]) @expose_as_static - def is_not_null(self) -> "BooleanExpr": + def is_not_null(self) -> "BooleanExpression": """Creates an expression that checks if the value of a field is not 'Null'. Example: >>> Field.of("value").is_not_null() Returns: - A new `Expr` representing the 'isNotNull' check. + A new `Expression` representing the 'isNotNull' check. """ - return BooleanExpr("is_not_null", [self]) + return BooleanExpression("is_not_null", [self]) @expose_as_static def is_error(self): @@ -865,12 +875,12 @@ def is_error(self): >>> Field.of("value").divide("string").is_error() Returns: - A new `Expr` representing the isError operation. + A new `Expression` representing the isError operation. 
""" return Function("is_error", [self]) @expose_as_static - def if_error(self, then_value: Expr | CONSTANT_TYPE) -> "Expr": + def if_error(self, then_value: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns ``then_value`` if this expression evaluates to an error. Otherwise, returns the value of this expression. @@ -882,14 +892,14 @@ def if_error(self, then_value: Expr | CONSTANT_TYPE) -> "Expr": then_value: The value to return if this expression evaluates to an error. Returns: - A new `Expr` representing the ifError operation. + A new `Expression` representing the ifError operation. """ return Function( "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)] ) @expose_as_static - def exists(self) -> "BooleanExpr": + def exists(self) -> "BooleanExpression": """Creates an expression that checks if a field exists in the document. Example: @@ -897,12 +907,12 @@ def exists(self) -> "BooleanExpr": >>> Field.of("phoneNumber").exists() Returns: - A new `Expr` representing the 'exists' check. + A new `Expression` representing the 'exists' check. """ - return BooleanExpr("exists", [self]) + return BooleanExpression("exists", [self]) @expose_as_static - def sum(self) -> "Expr": + def sum(self) -> "Expression": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: @@ -915,7 +925,7 @@ def sum(self) -> "Expr": return AggregateFunction("sum", [self]) @expose_as_static - def average(self) -> "Expr": + def average(self) -> "Expression": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. @@ -929,7 +939,7 @@ def average(self) -> "Expr": return AggregateFunction("average", [self]) @expose_as_static - def count(self) -> "Expr": + def count(self) -> "Expression": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -943,7 +953,7 @@ def count(self) -> "Expr": return AggregateFunction("count", [self]) @expose_as_static - def count_if(self) -> "Expr": + def count_if(self) -> "Expression": """Creates an aggregation that counts the number of values of the provided field or expression that evaluate to True. @@ -958,7 +968,7 @@ def count_if(self) -> "Expr": return AggregateFunction("count_if", [self]) @expose_as_static - def count_distinct(self) -> "Expr": + def count_distinct(self) -> "Expression": """Creates an aggregation that counts the number of distinct values of the provided field or expression. @@ -972,7 +982,7 @@ def count_distinct(self) -> "Expr": return AggregateFunction("count_distinct", [self]) @expose_as_static - def minimum(self) -> "Expr": + def minimum(self) -> "Expression": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: @@ -985,7 +995,7 @@ def minimum(self) -> "Expr": return AggregateFunction("minimum", [self]) @expose_as_static - def maximum(self) -> "Expr": + def maximum(self) -> "Expression": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: @@ -998,7 +1008,7 @@ def maximum(self) -> "Expr": return AggregateFunction("maximum", [self]) @expose_as_static - def char_length(self) -> "Expr": + def char_length(self) -> "Expression": """Creates an expression that calculates the character length of a string. Example: @@ -1006,12 +1016,12 @@ def char_length(self) -> "Expr": >>> Field.of("name").char_length() Returns: - A new `Expr` representing the length of the string. 
+ A new `Expression` representing the length of the string. """ return Function("char_length", [self]) @expose_as_static - def byte_length(self) -> "Expr": + def byte_length(self) -> "Expression": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -1019,12 +1029,12 @@ def byte_length(self) -> "Expr": >>> Field.of("name").byte_length() Returns: - A new `Expr` representing the byte length of the string. + A new `Expression` representing the byte length of the string. """ return Function("byte_length", [self]) @expose_as_static - def like(self, pattern: Expr | str) -> "BooleanExpr": + def like(self, pattern: Expression | str) -> "BooleanExpression": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -1037,14 +1047,14 @@ def like(self, pattern: Expr | str) -> "BooleanExpr": pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. Returns: - A new `Expr` representing the 'like' comparison. + A new `Expression` representing the 'like' comparison. """ - return BooleanExpr( + return BooleanExpression( "like", [self, self._cast_to_expr_or_convert_to_constant(pattern)] ) @expose_as_static - def regex_contains(self, regex: Expr | str) -> "BooleanExpr": + def regex_contains(self, regex: Expression | str) -> "BooleanExpression": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -1058,14 +1068,14 @@ def regex_contains(self, regex: Expr | str) -> "BooleanExpr": regex: The regular expression (string or expression) to use for the search. Returns: - A new `Expr` representing the 'contains' comparison. + A new `Expression` representing the 'contains' comparison. """ - return BooleanExpr( + return BooleanExpression( "regex_contains", [self, self._cast_to_expr_or_convert_to_constant(regex)] ) @expose_as_static - def regex_match(self, regex: Expr | str) -> "BooleanExpr": + def regex_match(self, regex: Expression | str) -> "BooleanExpression": """Creates an expression that checks if a string matches a specified regular expression. Example: @@ -1078,14 +1088,14 @@ def regex_match(self, regex: Expr | str) -> "BooleanExpr": regex: The regular expression (string or expression) to use for the match. Returns: - A new `Expr` representing the regular expression match. + A new `Expression` representing the regular expression match. """ - return BooleanExpr( + return BooleanExpression( "regex_match", [self, self._cast_to_expr_or_convert_to_constant(regex)] ) @expose_as_static - def string_contains(self, substring: Expr | str) -> "BooleanExpr": + def string_contains(self, substring: Expression | str) -> "BooleanExpression": """Creates an expression that checks if this string expression contains a specified substring. Example: @@ -1098,15 +1108,15 @@ def string_contains(self, substring: Expr | str) -> "BooleanExpr": substring: The substring (string or expression) to use for the search. Returns: - A new `Expr` representing the 'contains' comparison. + A new `Expression` representing the 'contains' comparison. """ - return BooleanExpr( + return BooleanExpression( "string_contains", [self, self._cast_to_expr_or_convert_to_constant(substring)], ) @expose_as_static - def starts_with(self, prefix: Expr | str) -> "BooleanExpr": + def starts_with(self, prefix: Expression | str) -> "BooleanExpression": """Creates an expression that checks if a string starts with a given prefix. 
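        The prefix may itself be computed from another expression; as a sketch
        (assuming ``fullName`` and ``firstName`` string fields):

            >>> Field.of("fullName").starts_with(Field.of("firstName"))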
Example: @@ -1119,14 +1129,14 @@ def starts_with(self, prefix: Expr | str) -> "BooleanExpr": prefix: The prefix (string or expression) to check for. Returns: - A new `Expr` representing the 'starts with' comparison. + A new `Expression` representing the 'starts with' comparison. """ - return BooleanExpr( + return BooleanExpression( "starts_with", [self, self._cast_to_expr_or_convert_to_constant(prefix)] ) @expose_as_static - def ends_with(self, postfix: Expr | str) -> "BooleanExpr": + def ends_with(self, postfix: Expression | str) -> "BooleanExpression": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -1139,14 +1149,14 @@ def ends_with(self, postfix: Expr | str) -> "BooleanExpr": postfix: The postfix (string or expression) to check for. Returns: - A new `Expr` representing the 'ends with' comparison. + A new `Expression` representing the 'ends with' comparison. """ - return BooleanExpr( + return BooleanExpression( "ends_with", [self, self._cast_to_expr_or_convert_to_constant(postfix)] ) @expose_as_static - def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": + def string_concat(self, *elements: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that concatenates string expressions, fields or constants together. Example: @@ -1157,7 +1167,7 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": *elements: The expressions or constants (typically strings) to concatenate. Returns: - A new `Expr` representing the concatenated string. + A new `Expression` representing the concatenated string. """ return Function( "string_concat", @@ -1165,7 +1175,7 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": ) @expose_as_static - def to_lower(self) -> "Expr": + def to_lower(self) -> "Expression": """Creates an expression that converts a string to lowercase. Example: @@ -1173,12 +1183,12 @@ def to_lower(self) -> "Expr": >>> Field.of("name").to_lower() Returns: - A new `Expr` representing the lowercase string. + A new `Expression` representing the lowercase string. """ return Function("to_lower", [self]) @expose_as_static - def to_upper(self) -> "Expr": + def to_upper(self) -> "Expression": """Creates an expression that converts a string to uppercase. Example: @@ -1186,12 +1196,12 @@ def to_upper(self) -> "Expr": >>> Field.of("title").to_upper() Returns: - A new `Expr` representing the uppercase string. + A new `Expression` representing the uppercase string. """ return Function("to_upper", [self]) @expose_as_static - def trim(self) -> "Expr": + def trim(self) -> "Expression": """Creates an expression that removes leading and trailing whitespace from a string. Example: @@ -1199,12 +1209,12 @@ def trim(self) -> "Expr": >>> Field.of("userInput").trim() Returns: - A new `Expr` representing the trimmed string. + A new `Expression` representing the trimmed string. """ return Function("trim", [self]) @expose_as_static - def string_reverse(self) -> "Expr": + def string_reverse(self) -> "Expression": """Creates an expression that reverses a string. Example: @@ -1212,14 +1222,14 @@ def string_reverse(self) -> "Expr": >>> Field.of("userInput").reverse() Returns: - A new `Expr` representing the reversed string. + A new `Expression` representing the reversed string. 
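        Note that the method itself is named ``string_reverse`` (matching the
        underlying function name), so a direct call looks like:

            >>> Field.of("userInput").string_reverse()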
""" return Function("string_reverse", [self]) @expose_as_static def substring( - self, position: Expr | int, length: Expr | int | None = None - ) -> "Expr": + self, position: Expression | int, length: Expression | int | None = None + ) -> "Expression": """Creates an expression that returns a substring of the results of this expression. @@ -1233,7 +1243,7 @@ def substring( will end at the end of the input. Returns: - A new `Expr` representing the extracted substring. + A new `Expression` representing the extracted substring. """ args = [self, self._cast_to_expr_or_convert_to_constant(position)] if length is not None: @@ -1241,7 +1251,7 @@ def substring( return Function("substring", args) @expose_as_static - def join(self, delimeter: Expr | str) -> "Expr": + def join(self, delimeter: Expression | str) -> "Expression": """Creates an expression that joins the elements of an array into a string @@ -1252,14 +1262,14 @@ def join(self, delimeter: Expr | str) -> "Expr": delimiter: The delimiter to add between the elements of the array. Returns: - A new `Expr` representing the joined string. + A new `Expression` representing the joined string. """ return Function( "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] ) @expose_as_static - def map_get(self, key: str | Constant[str]) -> "Expr": + def map_get(self, key: str | Constant[str]) -> "Expression": """Accesses a value from the map produced by evaluating this expression. Example: @@ -1270,14 +1280,14 @@ def map_get(self, key: str | Constant[str]) -> "Expr": key: The key to access in the map. Returns: - A new `Expr` representing the value associated with the given key in the map. + A new `Expression` representing the value associated with the given key in the map. """ return Function( "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)] ) @expose_as_static - def map_remove(self, key: str | Constant[str]) -> "Expr": + def map_remove(self, key: str | Constant[str]) -> "Expression": """Remove a key from a the map produced by evaluating this expression. Example: @@ -1288,7 +1298,7 @@ def map_remove(self, key: str | Constant[str]) -> "Expr": key: The key to remove in the map. Returns: - A new `Expr` representing the map_remove operation. + A new `Expression` representing the map_remove operation. """ return Function( "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)] @@ -1296,8 +1306,11 @@ def map_remove(self, key: str | Constant[str]) -> "Expr": @expose_as_static def map_merge( - self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr - ) -> "Expr": + self, + *other_maps: Map + | dict[str | Constant[str], Expression | CONSTANT_TYPE] + | Expression, + ) -> "Expression": """Creates an expression that merges one or more dicts into a single map. Example: @@ -1308,7 +1321,7 @@ def map_merge( *other_maps: Sequence of maps to merge into the resulting map. Returns: - A new `Expr` representing the value associated with the given key in the map. + A new `Expression` representing the value associated with the given key in the map. """ return Function( "map_merge", @@ -1316,7 +1329,7 @@ def map_merge( ) @expose_as_static - def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": + def cosine_distance(self, other: Expression | list[float] | Vector) -> "Expression": """Calculates the cosine distance between two vectors. 
Example: @@ -1326,10 +1339,10 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": >>> Field.of("location").cosine_distance([37.7749, -122.4194]) Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + other: The other vector (represented as an Expression, list of floats, or Vector) to compare against. Returns: - A new `Expr` representing the cosine distance between the two vectors. + A new `Expression` representing the cosine distance between the two vectors. """ return Function( "cosine_distance", @@ -1340,7 +1353,9 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": ) @expose_as_static - def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": + def euclidean_distance( + self, other: Expression | list[float] | Vector + ) -> "Expression": """Calculates the Euclidean distance between two vectors. Example: @@ -1350,10 +1365,10 @@ def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + other: The other vector (represented as an Expression, list of floats, or Vector) to compare against. Returns: - A new `Expr` representing the Euclidean distance between the two vectors. + A new `Expression` representing the Euclidean distance between the two vectors. """ return Function( "euclidean_distance", @@ -1364,7 +1379,7 @@ def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": ) @expose_as_static - def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": + def dot_product(self, other: Expression | list[float] | Vector) -> "Expression": """Calculates the dot product between two vectors. Example: @@ -1374,10 +1389,10 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": >>> Field.of("docVector1").dot_product(Field.of("docVector2")) Args: - other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + other: The other vector (represented as an Expression, list of floats, or Vector) to calculate dot product with. Returns: - A new `Expr` representing the dot product between the two vectors. + A new `Expression` representing the dot product between the two vectors. """ return Function( "dot_product", @@ -1388,7 +1403,7 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": ) @expose_as_static - def vector_length(self) -> "Expr": + def vector_length(self) -> "Expression": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -1396,12 +1411,12 @@ def vector_length(self) -> "Expr": >>> Field.of("embedding").vector_length() Returns: - A new `Expr` representing the length of the vector. + A new `Expression` representing the length of the vector. """ return Function("vector_length", [self]) @expose_as_static - def timestamp_to_unix_micros(self) -> "Expr": + def timestamp_to_unix_micros(self) -> "Expression": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1412,12 +1427,12 @@ def timestamp_to_unix_micros(self) -> "Expr": >>> Field.of("timestamp").timestamp_to_unix_micros() Returns: - A new `Expr` representing the number of microseconds since the epoch. + A new `Expression` representing the number of microseconds since the epoch. 
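        The companion ``unix_micros_to_timestamp`` converts in the opposite
        direction, so the two calls can round-trip a value:

            >>> Field.of("timestamp").timestamp_to_unix_micros().unix_micros_to_timestamp()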
""" return Function("timestamp_to_unix_micros", [self]) @expose_as_static - def unix_micros_to_timestamp(self) -> "Expr": + def unix_micros_to_timestamp(self) -> "Expression": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1426,12 +1441,12 @@ def unix_micros_to_timestamp(self) -> "Expr": >>> Field.of("microseconds").unix_micros_to_timestamp() Returns: - A new `Expr` representing the timestamp. + A new `Expression` representing the timestamp. """ return Function("unix_micros_to_timestamp", [self]) @expose_as_static - def timestamp_to_unix_millis(self) -> "Expr": + def timestamp_to_unix_millis(self) -> "Expression": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1442,12 +1457,12 @@ def timestamp_to_unix_millis(self) -> "Expr": >>> Field.of("timestamp").timestamp_to_unix_millis() Returns: - A new `Expr` representing the number of milliseconds since the epoch. + A new `Expression` representing the number of milliseconds since the epoch. """ return Function("timestamp_to_unix_millis", [self]) @expose_as_static - def unix_millis_to_timestamp(self) -> "Expr": + def unix_millis_to_timestamp(self) -> "Expression": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1456,12 +1471,12 @@ def unix_millis_to_timestamp(self) -> "Expr": >>> Field.of("milliseconds").unix_millis_to_timestamp() Returns: - A new `Expr` representing the timestamp. + A new `Expression` representing the timestamp. """ return Function("unix_millis_to_timestamp", [self]) @expose_as_static - def timestamp_to_unix_seconds(self) -> "Expr": + def timestamp_to_unix_seconds(self) -> "Expression": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1472,12 +1487,12 @@ def timestamp_to_unix_seconds(self) -> "Expr": >>> Field.of("timestamp").timestamp_to_unix_seconds() Returns: - A new `Expr` representing the number of seconds since the epoch. + A new `Expression` representing the number of seconds since the epoch. """ return Function("timestamp_to_unix_seconds", [self]) @expose_as_static - def unix_seconds_to_timestamp(self) -> "Expr": + def unix_seconds_to_timestamp(self) -> "Expression": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1486,12 +1501,14 @@ def unix_seconds_to_timestamp(self) -> "Expr": >>> Field.of("seconds").unix_seconds_to_timestamp() Returns: - A new `Expr` representing the timestamp. + A new `Expression` representing the timestamp. """ return Function("unix_seconds_to_timestamp", [self]) @expose_as_static - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": + def timestamp_add( + self, unit: Expression | str, amount: Expression | float + ) -> "Expression": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -1506,7 +1523,7 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": amount: The expression or float representing the amount of time to add. Returns: - A new `Expr` representing the resulting timestamp. + A new `Expression` representing the resulting timestamp. 
""" return Function( "timestamp_add", @@ -1518,7 +1535,9 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": ) @expose_as_static - def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": + def timestamp_subtract( + self, unit: Expression | str, amount: Expression | float + ) -> "Expression": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: @@ -1533,7 +1552,7 @@ def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": amount: The expression or float representing the amount of time to subtract. Returns: - A new `Expr` representing the resulting timestamp. + A new `Expression` representing the resulting timestamp. """ return Function( "timestamp_subtract", @@ -1553,7 +1572,7 @@ def collection_id(self): >>> Field.of("__name__").collection_id() Returns: - A new `Expr` representing the collection ID. + A new `Expression` representing the collection ID. """ return Function("collection_id", [self]) @@ -1566,7 +1585,7 @@ def document_id(self): >>> Field.of("__name__").document_id() Returns: - A new `Expr` representing the document ID. + A new `Expression` representing the document ID. """ return Function("document_id", [self]) @@ -1594,7 +1613,7 @@ def descending(self) -> Ordering: """ return Ordering(self, Ordering.Direction.DESCENDING) - def as_(self, alias: str) -> "AliasedExpr": + def as_(self, alias: str) -> "AliasedExpression": """Assigns an alias to this expression. Aliases are useful for renaming fields in the output of a stage or for giving meaningful @@ -1610,13 +1629,13 @@ def as_(self, alias: str) -> "AliasedExpr": alias: The alias to assign to this expression. Returns: - A new `Selectable` (typically an `AliasedExpr`) that wraps this + A new `Selectable` (typically an `AliasedExpression`) that wraps this expression and associates it with the provided alias. """ - return AliasedExpr(self, alias) + return AliasedExpression(self, alias) -class Constant(Expr, Generic[CONSTANT_TYPE]): +class Constant(Expression, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" def __init__(self, value: CONSTANT_TYPE): @@ -1643,13 +1662,13 @@ def _to_pb(self) -> Value: return encode_value(self.value) -class Function(Expr): +class Function(Expression): """A base class for expressions that represent function calls.""" def __init__( self, name: str, - params: Sequence[Expr], + params: Sequence[Expression], *, use_infix_repr: bool = True, infix_name_override: str | None = None, @@ -1693,22 +1712,8 @@ def _to_pb(self): class AggregateFunction(Function): """A base class for aggregation functions that operate across multiple inputs.""" - def as_(self, alias: str) -> "AliasedAggregate": - """Assigns an alias to this expression. - - Aliases are useful for renaming fields in the output of a stage or for giving meaningful - names to calculated values. - - Args: - alias: The alias to assign to this expression. - - Returns: A new AliasedAggregate that wraps this expression and associates it with the - provided alias. 
- """ - return AliasedAggregate(self, alias) - -class Selectable(Expr): +class Selectable(Expression): """Base class for expressions that can be selected or aliased in projection stages.""" def __eq__(self, other): @@ -1744,10 +1749,10 @@ def _to_value(field_list: Sequence[Selectable]) -> Value: ) -T = TypeVar("T", bound=Expr) +T = TypeVar("T", bound=Expression) -class AliasedExpr(Selectable, Generic[T]): +class AliasedExpression(Selectable, Generic[T]): """Wraps an expression with an alias.""" def __init__(self, expr: T, alias: str): @@ -1764,23 +1769,6 @@ def _to_pb(self): return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) -class AliasedAggregate: - """Wraps an aggregate with an alias""" - - def __init__(self, expr: AggregateFunction, alias: str): - self.expr = expr - self.alias = alias - - def _to_map(self): - return self.alias, self.expr._to_pb() - - def __repr__(self): - return f"{self.expr}.as_('{self.alias}')" - - def _to_pb(self): - return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - - class Field(Selectable): """Represents a reference to a field within a document.""" @@ -1818,14 +1806,15 @@ def _to_pb(self): return Value(field_reference_value=self.path) -class BooleanExpr(Function): +class BooleanExpression(Function): """Filters the given data in some way.""" @staticmethod def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters + BooleanExpression._from_query_filter_pb(f, client) + for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) @@ -1879,7 +1868,7 @@ def _from_query_filter_pb(filter_pb, client): or filter_pb.field_filter or filter_pb.unary_filter ) - return BooleanExpr._from_query_filter_pb(f, client) + return BooleanExpression._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") @@ -1889,13 +1878,13 @@ class Array(Function): Creates an expression that creates a Firestore array value from an input list. Example: - >>> Expr.array(["bar", Field.of("baz")]) + >>> Expression.array(["bar", Field.of("baz")]) Args: elements: The input list to evaluate in the expression """ - def __init__(self, elements: list[Expr | CONSTANT_TYPE]): + def __init__(self, elements: list[Expression | CONSTANT_TYPE]): if not isinstance(elements, list): raise TypeError("Array must be constructed with a list") converted_elements = [ @@ -1912,13 +1901,13 @@ class Map(Function): Creates an expression that creates a Firestore map value from an input dict. Example: - >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + >>> Expression.map({"foo": "bar", "baz": Field.of("baz")}) Args: elements: The input dict to evaluate in the expression """ - def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]): + def __init__(self, elements: dict[str | Constant[str], Expression | CONSTANT_TYPE]): element_list = [] for k, v in elements.items(): element_list.append(self._cast_to_expr_or_convert_to_constant(k)) @@ -1933,7 +1922,7 @@ def __repr__(self): return f"Map({d})" -class And(BooleanExpr): +class And(BooleanExpression): """ Represents an expression that performs a logical 'AND' operation on multiple filter conditions. @@ -1946,11 +1935,11 @@ class And(BooleanExpr): *conditions: The filter conditions to 'AND' together. 
""" - def __init__(self, *conditions: "BooleanExpr"): + def __init__(self, *conditions: "BooleanExpression"): super().__init__("and", conditions, use_infix_repr=False) -class Not(BooleanExpr): +class Not(BooleanExpression): """ Represents an expression that negates a filter condition. @@ -1962,11 +1951,11 @@ class Not(BooleanExpr): condition: The filter condition to negate. """ - def __init__(self, condition: BooleanExpr): + def __init__(self, condition: BooleanExpression): super().__init__("not", [condition], use_infix_repr=False) -class Or(BooleanExpr): +class Or(BooleanExpression): """ Represents expression that performs a logical 'OR' operation on multiple filter conditions. @@ -1979,11 +1968,11 @@ class Or(BooleanExpr): *conditions: The filter conditions to 'OR' together. """ - def __init__(self, *conditions: "BooleanExpr"): + def __init__(self, *conditions: "BooleanExpression"): super().__init__("or", conditions, use_infix_repr=False) -class Xor(BooleanExpr): +class Xor(BooleanExpression): """ Represents an expression that performs a logical 'XOR' (exclusive OR) operation on multiple filter conditions. @@ -1996,11 +1985,11 @@ class Xor(BooleanExpr): *conditions: The filter conditions to 'XOR' together. """ - def __init__(self, conditions: Sequence["BooleanExpr"]): + def __init__(self, conditions: Sequence["BooleanExpression"]): super().__init__("xor", conditions, use_infix_repr=False) -class Conditional(BooleanExpr): +class Conditional(BooleanExpression): """ Represents a conditional expression that evaluates to a 'then' expression if a condition is true and an 'else' expression if the condition is false. @@ -2015,7 +2004,9 @@ class Conditional(BooleanExpr): else_expr: The expression to return if the condition is false """ - def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): + def __init__( + self, condition: BooleanExpression, then_expr: Expression, else_expr: Expression + ): super().__init__( "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) @@ -2036,7 +2027,7 @@ class Count(AggregateFunction): expression: The expression or field to count. If None, counts all stage inputs. """ - def __init__(self, expression: Expr | None = None): + def __init__(self, expression: Expression | None = None): expression_list = [expression] if expression else [] super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) @@ -2045,7 +2036,7 @@ class CurrentTimestamp(Function): """Creates an expression that returns the current timestamp Returns: - A new `Expr` representing the current timestamp. + A new `Expression` representing the current timestamp. 
""" def __init__(self): diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 6d83ae533..f4328afa4 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import Generic, TypeVar, TYPE_CHECKING -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_pipeline import _BasePipeline from google.cloud.firestore_v1._helpers import DOCUMENT_PATH_DELIMITER diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py similarity index 95% rename from google/cloud/firestore_v1/_pipeline_stages.py rename to google/cloud/firestore_v1/pipeline_stages.py index 62503404e..95ce32021 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -24,11 +24,10 @@ from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( AggregateFunction, - Expr, - AliasedAggregate, - AliasedExpr, + Expression, + AliasedExpression, Field, - BooleanExpr, + BooleanExpression, Selectable, Ordering, ) @@ -167,8 +166,8 @@ class Aggregate(Stage): def __init__( self, - *args: AliasedExpr[AggregateFunction], - accumulators: Sequence[AliasedAggregate] = (), + *args: AliasedExpression[AggregateFunction], + accumulators: Sequence[AliasedExpression[AggregateFunction]] = (), groups: Sequence[str | Selectable] = (), ): super().__init__() @@ -268,13 +267,13 @@ class FindNearest(Stage): def __init__( self, - field: str | Expr, + field: str | Expression, vector: Sequence[float] | Vector, distance_measure: "DistanceMeasure" | str, options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") - self.field: Expr = Field(field) if isinstance(field, str) else field + self.field: Expression = Field(field) if isinstance(field, str) else field self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) self.distance_measure = ( distance_measure @@ -299,18 +298,22 @@ def _pb_options(self) -> dict[str, Value]: return options -class GenericStage(Stage): +class RawStage(Stage): """Represents a generic, named stage with parameters.""" def __init__( - self, name: str, *params: Expr | Value, options: dict[str, Expr | Value] = {} + self, + name: str, + *params: Expression | Value, + options: dict[str, Expression | Value] = {}, ): super().__init__(name) self.params: list[Value] = [ - p._to_pb() if isinstance(p, Expr) else p for p in params + p._to_pb() if isinstance(p, Expression) else p for p in params ] self.options: dict[str, Value] = { - k: v._to_pb() if isinstance(v, Expr) else v for k, v in options.items() + k: v._to_pb() if isinstance(v, Expression) else v + for k, v in options.items() } def _pb_args(self): @@ -448,7 +451,7 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" - def __init__(self, condition: BooleanExpr): + def __init__(self, condition: BooleanExpression): super().__init__() self.condition = condition diff --git a/tests/system/pipeline_e2e/aggregates.yaml b/tests/system/pipeline_e2e/aggregates.yaml index 18902aff4..9593213ed 100644 --- a/tests/system/pipeline_e2e/aggregates.yaml +++ b/tests/system/pipeline_e2e/aggregates.yaml @@ -3,8 +3,8 @@ tests: pipeline: - Collection: books - Aggregate: - - 
AliasedExpr: - - Expr.count: + - AliasedExpression: + - Function.count: - Field: rating - "count" assert_results: @@ -29,9 +29,9 @@ tests: pipeline: - Collection: books - Aggregate: - - AliasedExpr: - - Expr.count_if: - - Expr.greater_than: + - AliasedExpression: + - Function.count_if: + - Function.greater_than: - Field: rating - Constant: 4.2 - "count_if_rating_gt_4_2" @@ -61,8 +61,8 @@ tests: pipeline: - Collection: books - Aggregate: - - AliasedExpr: - - Expr.count_distinct: + - AliasedExpression: + - Function.count_distinct: - Field: genre - "distinct_genres" assert_results: @@ -87,20 +87,20 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: genre - Constant: Science Fiction - Aggregate: - - AliasedExpr: - - Expr.count: + - AliasedExpression: + - Function.count: - Field: rating - "count" - - AliasedExpr: - - Expr.average: + - AliasedExpression: + - Function.average: - Field: rating - "avg_rating" - - AliasedExpr: - - Expr.maximum: + - AliasedExpression: + - Function.maximum: - Field: rating - "max_rating" assert_results: @@ -144,7 +144,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.less_than: + - Function.less_than: - Field: published - Constant: 1900 - Aggregate: @@ -155,18 +155,18 @@ tests: pipeline: - Collection: books - Where: - - Expr.less_than: + - Function.less_than: - Field: published - Constant: 1984 - Aggregate: accumulators: - - AliasedExpr: - - Expr.average: + - AliasedExpression: + - Function.average: - Field: rating - "avg_rating" groups: [genre] - Where: - - Expr.greater_than: + - Function.greater_than: - Field: avg_rating - Constant: 4.3 - Sort: @@ -225,16 +225,16 @@ tests: pipeline: - Collection: books - Aggregate: - - AliasedExpr: - - Expr.count: + - AliasedExpression: + - Function.count: - Field: rating - "count" - - AliasedExpr: - - Expr.maximum: + - AliasedExpression: + - Function.maximum: - Field: rating - "max_rating" - - AliasedExpr: - - Expr.minimum: + - AliasedExpression: + - Function.minimum: - Field: published - "min_published" assert_results: @@ -266,4 +266,19 @@ tests: - fieldReferenceValue: published name: minimum - mapValue: {} - name: aggregate \ No newline at end of file + name: aggregate + - description: testSum + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpression: + - Function.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml index d63a63402..3da16264d 100644 --- a/tests/system/pipeline_e2e/array.yaml +++ b/tests/system/pipeline_e2e/array.yaml @@ -3,7 +3,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.array_contains: + - Function.array_contains: - Field: tags - Constant: comedy assert_results: @@ -33,7 +33,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.array_contains_any: + - Function.array_contains_any: - Field: tags - - Constant: comedy - Constant: classic @@ -81,7 +81,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.array_contains_all: + - Function.array_contains_all: - Field: tags - - Constant: adventure - Constant: magic @@ -116,12 +116,12 @@ tests: pipeline: - Collection: books - Select: - - AliasedExpr: - - Expr.array_length: + - AliasedExpression: + - Function.array_length: - Field: tags - "tagsCount" - Where: - - Expr.equal: + - Function.equal: - Field: tagsCount - Constant: 3 assert_results: # All documents have 3 tags @@ -161,12 
+161,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - AliasedExpr: - - Expr.array_reverse: + - AliasedExpression: + - Function.array_reverse: - Field: tags - "reversedTags" assert_results: @@ -178,14 +178,14 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - AliasedExpr: - - Expr.array_concat: + - AliasedExpression: + - Function.array_concat: - Field: tags - - Array: ["new_tag", "another_tag"] + - Constant: ["new_tag", "another_tag"] - "concatenatedTags" assert_results: - concatenatedTags: @@ -225,15 +225,15 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "Dune" - Select: - - AliasedExpr: - - Expr.array_concat: + - AliasedExpression: + - Function.array_concat: - Field: tags - - ["sci-fi"] - - ["classic", "epic"] + - Constant: ["sci-fi"] + - Constant: ["classic", "epic"] - "concatenatedTags" assert_results: - concatenatedTags: @@ -278,13 +278,13 @@ tests: pipeline: - Collection: books - AddFields: - - AliasedExpr: - - Expr.array_concat: + - AliasedExpression: + - Function.array_concat: - Field: tags - Array: ["Dystopian"] - "new_tags" - Where: - - Expr.array_contains_any: + - Function.array_contains_any: - Field: new_tags - - Constant: non_existent_tag - Field: genre @@ -351,8 +351,8 @@ tests: - Collection: books - Limit: 1 - Select: - - AliasedExpr: - - Expr.array_concat: + - AliasedExpression: + - Function.array_concat: - Array: [1, 2, 3] - Array: [4, 5] - "concatenated" diff --git a/tests/system/pipeline_e2e/date_and_time.yaml b/tests/system/pipeline_e2e/date_and_time.yaml index bbb5f34fe..cb5323dc1 100644 --- a/tests/system/pipeline_e2e/date_and_time.yaml +++ b/tests/system/pipeline_e2e/date_and_time.yaml @@ -4,15 +4,15 @@ tests: - Collection: books - Limit: 1 - Select: - - AliasedExpr: + - AliasedExpression: - And: - - Expr.greater_than_or_equal: + - Function.greater_than_or_equal: - CurrentTimestamp: [] - - Expr.unix_seconds_to_timestamp: + - Function.unix_seconds_to_timestamp: - Constant: 1735689600 # 2025-01-01 - - Expr.less_than: + - Function.less_than: - CurrentTimestamp: [] - - Expr.unix_seconds_to_timestamp: + - Function.unix_seconds_to_timestamp: - Constant: 4892438400 # 2125-01-01 - "is_between_2025_and_2125" assert_results: @@ -56,38 +56,38 @@ tests: pipeline: - Collection: timestamps - Select: - - AliasedExpr: - - Expr.timestamp_to_unix_micros: + - AliasedExpression: + - Function.timestamp_to_unix_micros: - Field: time - "micros" - - AliasedExpr: - - Expr.timestamp_to_unix_millis: + - AliasedExpression: + - Function.timestamp_to_unix_millis: - Field: time - "millis" - - AliasedExpr: - - Expr.timestamp_to_unix_seconds: + - AliasedExpression: + - Function.timestamp_to_unix_seconds: - Field: time - "seconds" - - AliasedExpr: - - Expr.unix_micros_to_timestamp: + - AliasedExpression: + - Function.unix_micros_to_timestamp: - Field: micros - "from_micros" - - AliasedExpr: - - Expr.unix_millis_to_timestamp: + - AliasedExpression: + - Function.unix_millis_to_timestamp: - Field: millis - "from_millis" - - AliasedExpr: - - Expr.unix_seconds_to_timestamp: + - AliasedExpression: + - Function.unix_seconds_to_timestamp: - Field: seconds - "from_seconds" - - AliasedExpr: - - Expr.timestamp_add: + - AliasedExpression: + - Function.timestamp_add: - Field: time - Constant: "day" - Constant: 1 - 
"plus_day" - - AliasedExpr: - - Expr.timestamp_subtract: + - AliasedExpression: + - Function.timestamp_subtract: - Field: time - Constant: "hour" - Constant: 1 diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml index c135853d1..23e98cf3d 100644 --- a/tests/system/pipeline_e2e/general.yaml +++ b/tests/system/pipeline_e2e/general.yaml @@ -56,14 +56,14 @@ tests: pipeline: - Collection: books - AddFields: - - AliasedExpr: - - Expr.string_concat: + - AliasedExpression: + - Function.string_concat: - Field: author - Constant: _ - Field: title - "author_title" - - AliasedExpr: - - Expr.string_concat: + - AliasedExpression: + - Function.string_concat: - Field: title - Constant: _ - Field: author @@ -144,7 +144,6 @@ tests: expression: fieldReferenceValue: author_title name: sort - - description: testPipelineWithOffsetAndLimit pipeline: - Collection: books @@ -192,116 +191,6 @@ tests: title: fieldReferenceValue: title name: select - - description: testArithmeticOperations - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: title - - Constant: To Kill a Mockingbird - - Select: - - AliasedExpr: - - Expr.add: - - Field: rating - - Constant: 1 - - "ratingPlusOne" - - AliasedExpr: - - Expr.subtract: - - Field: published - - Constant: 1900 - - "yearsSince1900" - - AliasedExpr: - - Expr.multiply: - - Field: rating - - Constant: 10 - - "ratingTimesTen" - - AliasedExpr: - - Expr.divide: - - Field: rating - - Constant: 2 - - "ratingDividedByTwo" - - AliasedExpr: - - Expr.multiply: - - Field: rating - - Constant: 20 - - "ratingTimes20" - - AliasedExpr: - - Expr.add: - - Field: rating - - Constant: 3 - - "ratingPlus3" - - AliasedExpr: - - Expr.mod: - - Field: rating - - Constant: 2 - - "ratingMod2" - assert_results: - - ratingPlusOne: 5.2 - yearsSince1900: 60 - ratingTimesTen: 42.0 - ratingDividedByTwo: 2.1 - ratingTimes20: 84 - ratingPlus3: 7.2 - ratingMod2: 0.20000000000000018 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: title - - stringValue: To Kill a Mockingbird - name: equal - name: where - - args: - - mapValue: - fields: - ratingDividedByTwo: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: divide - ratingPlusOne: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '1' - name: add - ratingTimesTen: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '10' - name: multiply - yearsSince1900: - functionValue: - args: - - fieldReferenceValue: published - - integerValue: '1900' - name: subtract - ratingTimes20: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '20' - name: multiply - ratingPlus3: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '3' - name: add - ratingMod2: - functionValue: - args: - - fieldReferenceValue: rating - - integerValue: '2' - name: mod - name: select - description: testSampleLimit pipeline: - Collection: books @@ -338,14 +227,14 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: genre - Constant: Romance - Union: - Pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: genre - Constant: Dystopian - Select: @@ -410,12 +299,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - AliasedExpr: - - 
Expr.document_id: + - AliasedExpression: + - Function.document_id: - Field: __name__ - "doc_id" assert_results: @@ -442,28 +331,13 @@ tests: args: - fieldReferenceValue: __name__ name: select - - description: testSum - pipeline: - - Collection: books - - Where: - - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Aggregate: - - AliasedExpr: - - Expr.sum: - - Field: rating - - "total_rating" - assert_results: - - total_rating: 8.8 - - description: testCollectionId pipeline: - Collection: books - Limit: 1 - Select: - - AliasedExpr: - - Expr.collection_id: + - AliasedExpression: + - Function.collection_id: - Field: __name__ - "collectionName" assert_results: @@ -601,13 +475,13 @@ tests: - Distinct: - title - Aggregate: - - AliasedExpr: + - AliasedExpression: - Count: [] - count - Select: - - AliasedExpr: + - AliasedExpression: - Conditional: - - Expr.greater_than_or_equal: + - Function.greater_than_or_equal: - Field: count - Constant: 10 - Constant: True @@ -654,18 +528,18 @@ tests: - booleanValue: true - booleanValue: false name: conditional - - description: testGenericStage + - description: testRawStage pipeline: - - GenericStage: + - RawStage: - "collection" - Value: reference_value: "/books" - - GenericStage: + - RawStage: - "where" - - Expr.equal: + - Function.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - - GenericStage: + - RawStage: - "select" - Value: map_value: @@ -697,7 +571,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: @@ -735,7 +609,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml index 203be290d..efc91c29a 100644 --- a/tests/system/pipeline_e2e/logical.yaml +++ b/tests/system/pipeline_e2e/logical.yaml @@ -4,10 +4,10 @@ tests: - Collection: books - Where: - And: - - Expr.greater_than: + - Function.greater_than: - Field: rating - Constant: 4.5 - - Expr.equal: + - Function.equal: - Field: genre - Constant: Science Fiction assert_results: @@ -49,10 +49,10 @@ tests: - Collection: books - Where: - Or: - - Expr.equal: + - Function.equal: - Field: genre - Constant: Romance - - Expr.equal: + - Function.equal: - Field: genre - Constant: Dystopian - Select: @@ -105,13 +105,13 @@ tests: - Collection: books - Where: - And: - - Expr.greater_than: + - Function.greater_than: - Field: rating - Constant: 4.2 - - Expr.less_than_or_equal: + - Function.less_than_or_equal: - Field: rating - Constant: 4.5 - - Expr.not_equal: + - Function.not_equal: - Field: genre - Constant: Science Fiction - Select: @@ -176,13 +176,13 @@ tests: - Where: - Or: - And: - - Expr.greater_than: + - Function.greater_than: - Field: rating - Constant: 4.5 - - Expr.equal: + - Function.equal: - Field: genre - Constant: Science Fiction - - Expr.less_than: + - Function.less_than: - Field: published - Constant: 1900 - Select: @@ -243,12 +243,12 @@ tests: - Collection: books - Where: - Not: - - Expr.is_nan: + - Function.is_nan: - Field: rating - Select: - - AliasedExpr: + - AliasedExpression: - Not: - - Expr.is_nan: + - Function.is_nan: - Field: rating - "ratingIsNotNaN" - Limit: 1 @@ -288,7 +288,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.is_not_null: + - Function.is_not_null: - Field: rating assert_count: 10 assert_proto: @@ -307,7 +307,7 @@ tests: pipeline: - 
Collection: books - Where: - - Expr.is_not_nan: + - Function.is_not_nan: - Field: rating assert_count: 10 assert_proto: @@ -326,7 +326,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.is_absent: + - Function.is_absent: - Field: awards.pulitzer assert_count: 9 assert_proto: @@ -345,14 +345,14 @@ tests: pipeline: - Collection: books - Select: - - AliasedExpr: - - Expr.if_absent: + - AliasedExpression: + - Function.if_absent: - Field: awards.pulitzer - Constant: false - "pulitzer_award" - title - Where: - - Expr.equal: + - Function.equal: - Field: pulitzer_award - Constant: true assert_results: @@ -387,9 +387,9 @@ tests: pipeline: - Collection: books - Select: - - AliasedExpr: - - Expr.is_error: - - Expr.divide: + - AliasedExpression: + - Function.is_error: + - Function.divide: - Field: rating - Constant: "string" - "is_error_result" @@ -422,9 +422,9 @@ tests: pipeline: - Collection: books - Select: - - AliasedExpr: - - Expr.if_error: - - Expr.divide: + - AliasedExpression: + - Function.if_error: + - Function.divide: - Field: rating - Field: genre - Constant: "An error occurred" @@ -459,17 +459,17 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: Douglas Adams - Select: - - AliasedExpr: - - Expr.logical_maximum: + - AliasedExpression: + - Function.logical_maximum: - Field: rating - Constant: 4.5 - "max_rating" - - AliasedExpr: - - Expr.logical_minimum: + - AliasedExpression: + - Function.logical_minimum: - Field: published - Constant: 1900 - "min_published" @@ -509,7 +509,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.greater_than_or_equal: + - Function.greater_than_or_equal: - Field: rating - Constant: 4.6 - Select: @@ -529,11 +529,11 @@ tests: - Collection: books - Where: - And: - - Expr.equal_any: + - Function.equal_any: - Field: genre - - Constant: Romance - Constant: Dystopian - - Expr.not_equal_any: + - Function.not_equal_any: - Field: author - - Constant: "George Orwell" assert_results: @@ -565,9 +565,9 @@ tests: - Collection: books - Where: - And: - - Expr.exists: + - Function.exists: - Field: awards.pulitzer - - Expr.equal: + - Function.equal: - Field: awards.pulitzer - Constant: true - Select: @@ -579,10 +579,10 @@ tests: - Collection: books - Where: - Xor: - - - Expr.equal: + - - Function.equal: - Field: genre - Constant: Romance - - Expr.greater_than: + - Function.greater_than: - Field: published - Constant: 1980 - Select: @@ -605,9 +605,9 @@ tests: - Collection: books - Select: - title - - AliasedExpr: + - AliasedExpression: - Conditional: - - Expr.greater_than: + - Function.greater_than: - Field: published - Constant: 1950 - Constant: "Modern" @@ -631,7 +631,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.greater_than: + - Function.greater_than: - Field: published - Field: rating - Select: @@ -641,22 +641,22 @@ tests: pipeline: - Collection: books - Where: - - Expr.exists: + - Function.exists: - Field: non_existent_field assert_count: 0 - description: testConditionalWithFields pipeline: - Collection: books - Where: - - Expr.equal_any: + - Function.equal_any: - Field: title - - Constant: "Dune" - Constant: "1984" - Select: - title - - AliasedExpr: + - AliasedExpression: - Conditional: - - Expr.greater_than: + - Function.greater_than: - Field: published - Constant: 1950 - Field: author diff --git a/tests/system/pipeline_e2e/map.yaml b/tests/system/pipeline_e2e/map.yaml index 638fe0798..546af1351 100644 --- a/tests/system/pipeline_e2e/map.yaml +++ b/tests/system/pipeline_e2e/map.yaml 
@@ -7,14 +7,14 @@ tests: - Field: published - DESCENDING - Select: - - AliasedExpr: - - Expr.map_get: + - AliasedExpression: + - Function.map_get: - Field: awards - hugo - "hugoAward" - Field: title - Where: - - Expr.equal: + - Function.equal: - Field: hugoAward - Constant: true assert_results: @@ -59,16 +59,16 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "Dune" - AddFields: - - AliasedExpr: + - AliasedExpression: - Constant: "hugo" - "award_name" - Select: - - AliasedExpr: - - Expr.map_get: + - AliasedExpression: + - Function.map_get: - Field: awards - Field: award_name - "hugoAward" @@ -111,12 +111,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "Dune" - Select: - - AliasedExpr: - - Expr.map_remove: + - AliasedExpression: + - Function.map_remove: - Field: awards - "nebula" - "awards_removed" @@ -150,12 +150,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "Dune" - Select: - - AliasedExpr: - - Expr.map_merge: + - AliasedExpression: + - Function.map_merge: - Field: awards - Map: elements: {"new_award": true, "hugo": false} @@ -206,7 +206,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: awards.hugo - Constant: true - Sort: @@ -255,8 +255,8 @@ tests: - Collection: books - Limit: 1 - Select: - - AliasedExpr: - - Expr.map_merge: + - AliasedExpression: + - Function.map_merge: - Map: elements: {"a": "orig", "b": "orig"} - Map: diff --git a/tests/system/pipeline_e2e/math.yaml b/tests/system/pipeline_e2e/math.yaml index a5a47d4c0..b62c0510b 100644 --- a/tests/system/pipeline_e2e/math.yaml +++ b/tests/system/pipeline_e2e/math.yaml @@ -3,61 +3,61 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "Dune" - Select: - - AliasedExpr: - - Expr.add: + - AliasedExpression: + - Function.add: - Field: published - Field: rating - "pub_plus_rating" assert_results: - pub_plus_rating: 1969.6 - - description: testMathExpressions + - description: testMathFunctions pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - - AliasedExpr: - - Expr.abs: + - AliasedExpression: + - Function.abs: - Field: rating - "abs_rating" - - AliasedExpr: - - Expr.ceil: + - AliasedExpression: + - Function.ceil: - Field: rating - "ceil_rating" - - AliasedExpr: - - Expr.exp: + - AliasedExpression: + - Function.exp: - Field: rating - "exp_rating" - - AliasedExpr: - - Expr.floor: + - AliasedExpression: + - Function.floor: - Field: rating - "floor_rating" - - AliasedExpr: - - Expr.ln: + - AliasedExpression: + - Function.ln: - Field: rating - "ln_rating" - - AliasedExpr: - - Expr.log10: + - AliasedExpression: + - Function.log10: - Field: rating - "log_rating_base10" - - AliasedExpr: - - Expr.log: + - AliasedExpression: + - Function.log: - Field: rating - Constant: 2 - "log_rating_base2" - - AliasedExpr: - - Expr.pow: + - AliasedExpression: + - Function.pow: - Field: rating - Constant: 2 - "pow_rating" - - AliasedExpr: - - Expr.sqrt: + - AliasedExpression: + - Function.sqrt: - Field: rating - "sqrt_rating" assert_results_approximate: @@ -134,19 +134,19 @@ tests: - fieldReferenceValue: rating name: sqrt name: select - - description: testRoundExpressions + - description: testRoundFunctions pipeline: - Collection: books - Where: - - Expr.equal_any: + - 
Function.equal_any: - Field: title - - Constant: "To Kill a Mockingbird" # rating 4.2 - Constant: "Pride and Prejudice" # rating 4.5 - Constant: "The Lord of the Rings" # rating 4.7 - Select: - title - - AliasedExpr: - - Expr.round: + - AliasedExpression: + - Function.round: - Field: rating - "round_rating" - Sort: @@ -201,42 +201,42 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - - AliasedExpr: - - Expr.add: + - AliasedExpression: + - Function.add: - Field: rating - Constant: 1 - "ratingPlusOne" - - AliasedExpr: - - Expr.subtract: + - AliasedExpression: + - Function.subtract: - Field: published - Constant: 1900 - "yearsSince1900" - - AliasedExpr: - - Expr.multiply: + - AliasedExpression: + - Function.multiply: - Field: rating - Constant: 10 - "ratingTimesTen" - - AliasedExpr: - - Expr.divide: + - AliasedExpression: + - Function.divide: - Field: rating - Constant: 2 - "ratingDividedByTwo" - - AliasedExpr: - - Expr.multiply: + - AliasedExpression: + - Function.multiply: - Field: rating - Constant: 20 - "ratingTimes20" - - AliasedExpr: - - Expr.add: + - AliasedExpression: + - Function.add: - Field: rating - Constant: 3 - "ratingPlus3" - - AliasedExpr: - - Expr.mod: + - AliasedExpression: + - Function.mod: - Field: rating - Constant: 2 - "ratingMod2" diff --git a/tests/system/pipeline_e2e/string.yaml b/tests/system/pipeline_e2e/string.yaml index b1e3a0b64..d612483e1 100644 --- a/tests/system/pipeline_e2e/string.yaml +++ b/tests/system/pipeline_e2e/string.yaml @@ -7,8 +7,8 @@ tests: - Field: author - ASCENDING - Select: - - AliasedExpr: - - Expr.string_concat: + - AliasedExpression: + - Function.string_concat: - Field: author - Constant: " - " - Field: title @@ -48,7 +48,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.starts_with: + - Function.starts_with: - Field: title - Constant: The - Select: @@ -93,7 +93,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.ends_with: + - Function.ends_with: - Field: title - Constant: y - Select: @@ -136,18 +136,18 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - AliasedExpr: - - Expr.concat: + - AliasedExpression: + - Function.concat: - Field: author - Constant: ": " - Field: title - "author_title" - - AliasedExpr: - - Expr.concat: + - AliasedExpression: + - Function.concat: - Field: tags - - Constant: "new_tag" - "concatenatedTags" @@ -162,20 +162,20 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - AliasedExpr: - - Expr.length: + - AliasedExpression: + - Function.length: - Field: title - "titleLength" - - AliasedExpr: - - Expr.length: + - AliasedExpression: + - Function.length: - Field: tags - "tagsLength" - - AliasedExpr: - - Expr.length: + - AliasedExpression: + - Function.length: - Field: awards - "awardsLength" assert_results: @@ -186,13 +186,13 @@ tests: pipeline: - Collection: books - Select: - - AliasedExpr: - - Expr.char_length: + - AliasedExpression: + - Function.char_length: - Field: title - "titleLength" - title - Where: - - Expr.greater_than: + - Function.greater_than: - Field: titleLength - Constant: 20 - Sort: @@ -244,12 +244,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - 
Expr.char_length: + - AliasedExpression: + - Function.char_length: - Field: title - "title_length" assert_results: @@ -280,13 +280,13 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: Douglas Adams - Select: - - AliasedExpr: - - Expr.byte_length: - - Expr.string_concat: + - AliasedExpression: + - Function.byte_length: + - Function.string_concat: - Field: title - Constant: _银河系漫游指南 - "title_byte_length" @@ -322,7 +322,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.like: + - Function.like: - Field: title - Constant: "%Guide%" - Select: @@ -334,7 +334,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.regex_contains: + - Function.regex_contains: - Field: title - Constant: "(?i)(the|of)" assert_count: 5 @@ -356,7 +356,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.regex_match: + - Function.regex_match: - Field: title - Constant: ".*(?i)(the|of).*" assert_count: 5 @@ -377,7 +377,7 @@ tests: pipeline: - Collection: books - Where: - - Expr.string_contains: + - Function.string_contains: - Field: title - Constant: "Hitchhiker's" - Select: @@ -388,12 +388,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - Expr.to_lower: + - AliasedExpression: + - Function.to_lower: - Field: title - "lower_title" assert_results: @@ -424,12 +424,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - Expr.to_upper: + - AliasedExpression: + - Function.to_upper: - Field: title - "upper_title" assert_results: @@ -460,13 +460,13 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - Expr.trim: - - Expr.string_concat: + - AliasedExpression: + - Function.trim: + - Function.string_concat: - Constant: " " - Field: title - Constant: " " @@ -504,12 +504,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Jane Austen" - Select: - - AliasedExpr: - - Expr.string_reverse: + - AliasedExpression: + - Function.string_reverse: - Field: title - "reversed_title" assert_results: @@ -540,12 +540,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - Expr.substring: + - AliasedExpression: + - Function.substring: - Field: title - Constant: 4 - Constant: 11 @@ -580,12 +580,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Fyodor Dostoevsky" - Select: - - AliasedExpr: - - Expr.substring: + - AliasedExpression: + - Function.substring: - Field: title - Constant: 10 - "substring_title" @@ -618,12 +618,12 @@ tests: pipeline: - Collection: books - Where: - - Expr.equal: + - Function.equal: - Field: author - Constant: "Douglas Adams" - Select: - - AliasedExpr: - - Expr.join: + - AliasedExpression: + - Function.join: - Field: tags - Constant: ", " - "joined_tags" diff --git a/tests/system/pipeline_e2e/vector.yaml b/tests/system/pipeline_e2e/vector.yaml index 15fc9bcaa..85d265c2d 100644 --- a/tests/system/pipeline_e2e/vector.yaml +++ b/tests/system/pipeline_e2e/vector.yaml @@ -3,8 +3,8 @@ tests: pipeline: - Collection: vectors - Select: - - AliasedExpr: - - Expr.vector_length: + - AliasedExpression: + 
- Function.vector_length: - Field: embedding - "embedding_length" - Sort: @@ -117,12 +117,12 @@ tests: pipeline: - Collection: vectors - Where: - - Expr.equal: + - Function.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - - AliasedExpr: - - Expr.dot_product: + - AliasedExpression: + - Function.dot_product: - Field: embedding - Vector: [1.0, 1.0, 1.0] - "dot_product_result" @@ -132,12 +132,12 @@ tests: pipeline: - Collection: vectors - Where: - - Expr.equal: + - Function.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - - AliasedExpr: - - Expr.euclidean_distance: + - AliasedExpression: + - Function.euclidean_distance: - Field: embedding - Vector: [1.0, 2.0, 3.0] - "euclidean_distance_result" @@ -147,12 +147,12 @@ tests: pipeline: - Collection: vectors - Where: - - Expr.equal: + - Function.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - - AliasedExpr: - - Expr.cosine_distance: + - AliasedExpression: + - Function.cosine_distance: - Field: embedding - Vector: [1.0, 2.0, 3.0] - "cosine_distance_result" diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index c7eaa6aff..fd74d9458 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -25,7 +25,7 @@ from google.protobuf.json_format import MessageToDict -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1 import pipeline_expressions as expr diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 66239f9ea..299283564 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -127,12 +127,12 @@ def test_avg_aggregation_no_alias_to_pb(): "in_alias,expected_alias", [("total", "total"), (None, "field_1")] ) def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate + from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression from google.cloud.firestore_v1.pipeline_expressions import Count count_aggregation = CountAggregation(alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, AliasedAggregate) + assert isinstance(got, AliasedExpression) assert got.alias == expected_alias assert isinstance(got.expr, Count) assert len(got.expr.params) == 0 @@ -143,11 +143,11 @@ def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate + from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression count_aggregation = SumAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, AliasedAggregate) + assert isinstance(got, AliasedExpression) assert got.alias == expected_alias assert got.expr.name == "sum" assert got.expr.params[0].path == expected_path @@ -158,11 +158,11 @@ def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alia [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - 
from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate + from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression count_aggregation = AvgAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, AliasedAggregate) + assert isinstance(got, AliasedExpression) assert got.alias == expected_alias assert got.expr.name == "average" assert got.expr.params[0].path == expected_path @@ -1033,7 +1033,7 @@ def test_aggregation_from_query(): ) def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate client = make_client() parent = client.collection("dee") @@ -1064,7 +1064,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): ) def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate client = make_client() parent = client.collection("dee") @@ -1094,7 +1094,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): ) def test_aggreation_to_pipeline_count(in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate from google.cloud.firestore_v1.pipeline_expressions import Count client = make_client() @@ -1137,7 +1137,7 @@ def test_aggreation_to_pipeline_count_increment(): def test_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.pipeline import Pipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate, Select client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index f51db482d..eca2ecef1 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -709,7 +709,7 @@ async def test_aggregation_query_stream_w_explain_options_analyze_false(): ) def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate client = make_async_client() parent = client.collection("dee") @@ -740,7 +740,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): ) def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate client = make_async_client() parent = client.collection("dee") @@ -770,7 +770,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): ) def test_async_aggreation_to_pipeline_count(in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - from 
google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate from google.cloud.firestore_v1.pipeline_expressions import Count client = make_async_client() @@ -813,7 +813,7 @@ def test_aggreation_to_pipeline_count_increment(): def test_async_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select + from google.cloud.firestore_v1.pipeline_stages import Collection, Aggregate, Select client = make_async_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 353997b8e..5b4df059a 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -605,7 +605,7 @@ def test_asynccollectionreference_recursive(): def test_asynccollectionreference_pipeline(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline - from google.cloud.firestore_v1._pipeline_stages import Collection + from google.cloud.firestore_v1.pipeline_stages import Collection client = make_async_client() collection = _make_async_collection_reference("collection", client=client) diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index a11a2951b..2fc39a906 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -15,7 +15,7 @@ import mock import pytest -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field @@ -67,33 +67,33 @@ def test_async_pipeline_repr_single_stage(): def test_async_pipeline_repr_multiple_stage(): stage_1 = stages.Collection("path") - stage_2 = stages.GenericStage("second", 2) - stage_3 = stages.GenericStage("third", 3) + stage_2 = stages.RawStage("second", 2) + stage_3 = stages.RawStage("third", 3) ppl = _make_async_pipeline(stage_1, stage_2, stage_3) repr_str = repr(ppl) assert repr_str == ( "AsyncPipeline(\n" " Collection(path='/path'),\n" - " GenericStage(name='second'),\n" - " GenericStage(name='third')\n" + " RawStage(name='second'),\n" + " RawStage(name='third')\n" ")" ) def test_async_pipeline_repr_long(): num_stages = 100 - stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + stage_list = [stages.RawStage("custom", i) for i in range(num_stages)] ppl = _make_async_pipeline(*stage_list) repr_str = repr(ppl) - assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("RawStage") == num_stages assert repr_str.count("\n") == num_stages + 1 def test_async_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline - stage_1 = stages.GenericStage("first") - stage_2 = stages.GenericStage("second") + stage_1 = stages.RawStage("first") + stage_2 = stages.RawStage("second") ppl = _make_async_pipeline(stage_1, stage_2) pb = ppl._to_pb() assert isinstance(pb, StructuredPipeline) @@ -103,9 +103,9 @@ def test_async_pipeline__to_pb(): def test_async_pipeline_append(): """append should create a new pipeline with the additional stage""" - stage_1 = stages.GenericStage("first") + stage_1 = stages.RawStage("first") ppl_1 = _make_async_pipeline(stage_1, client=object()) - stage_2 = stages.GenericStage("second") + stage_2 = stages.RawStage("second") ppl_2 = 
ppl_1._append(stage_2) assert ppl_1 != ppl_2 assert len(ppl_1.stages) == 1 @@ -130,7 +130,7 @@ async def test_async_pipeline_stream_empty(): mock_rpc = mock.AsyncMock() client._firestore_api.execute_pipeline = mock_rpc mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) - ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + ppl_1 = _make_async_pipeline(stages.RawStage("s"), client=client) results = [r async for r in ppl_1.stream()] assert results == [] @@ -159,7 +159,7 @@ async def test_async_pipeline_stream_no_doc_ref(): mock_rpc.return_value = _async_it( [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] ) - ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + ppl_1 = _make_async_pipeline(stages.RawStage("s"), client=client) results = [r async for r in ppl_1.stream()] assert len(results) == 1 @@ -401,8 +401,8 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): ("unnest", ("field_name", "alias"), stages.Unnest), ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), - ("generic_stage", ("stage_name",), stages.GenericStage), - ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("raw_stage", ("stage_name",), stages.RawStage), + ("raw_stage", ("stage_name", Field.of("n")), stages.RawStage), ("offset", (1,), stages.Offset), ("limit", (1,), stages.Limit), ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 7efa0dacf..925010070 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -18,7 +18,7 @@ import pytest from tests.unit.v1._test_helpers import make_client -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages def _make_base_query(*args, **kwargs): @@ -2040,7 +2040,9 @@ def test__query_pipeline_composite_filter(): client = make_client() in_filter = FieldFilter("field_a", "==", "value_a") query = client.collection("my_col").where(filter=in_filter) - with mock.patch.object(expr.BooleanExpr, "_from_query_filter_pb") as convert_mock: + with mock.patch.object( + expr.BooleanExpression, "_from_query_filter_pb" + ) as convert_mock: pipeline = query.pipeline() convert_mock.assert_called_once_with(in_filter._to_pb(), client) assert len(pipeline.stages) == 2 diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 9e615541a..76418204b 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -514,7 +514,7 @@ def test_stream_w_read_time(query_class): def test_collectionreference_pipeline(): from tests.unit.v1 import _test_helpers from google.cloud.firestore_v1.pipeline import Pipeline - from google.cloud.firestore_v1._pipeline_stages import Collection + from google.cloud.firestore_v1.pipeline_stages import Collection client = _test_helpers.make_client() collection = _make_collection_reference("collection", client=client) diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 161eef1cc..b6d353f1a 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -15,7 +15,7 @@ import mock import pytest -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field @@ -62,33 
+62,33 @@ def test_pipeline_repr_single_stage(): def test_pipeline_repr_multiple_stage(): stage_1 = stages.Collection("path") - stage_2 = stages.GenericStage("second", 2) - stage_3 = stages.GenericStage("third", 3) + stage_2 = stages.RawStage("second", 2) + stage_3 = stages.RawStage("third", 3) ppl = _make_pipeline(stage_1, stage_2, stage_3) repr_str = repr(ppl) assert repr_str == ( "Pipeline(\n" " Collection(path='/path'),\n" - " GenericStage(name='second'),\n" - " GenericStage(name='third')\n" + " RawStage(name='second'),\n" + " RawStage(name='third')\n" ")" ) def test_pipeline_repr_long(): num_stages = 100 - stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + stage_list = [stages.RawStage("custom", i) for i in range(num_stages)] ppl = _make_pipeline(*stage_list) repr_str = repr(ppl) - assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("RawStage") == num_stages assert repr_str.count("\n") == num_stages + 1 def test_pipeline__to_pb(): from google.cloud.firestore_v1.types.pipeline import StructuredPipeline - stage_1 = stages.GenericStage("first") - stage_2 = stages.GenericStage("second") + stage_1 = stages.RawStage("first") + stage_2 = stages.RawStage("second") ppl = _make_pipeline(stage_1, stage_2) pb = ppl._to_pb() assert isinstance(pb, StructuredPipeline) @@ -99,9 +99,9 @@ def test_pipeline__to_pb(): def test_pipeline_append(): """append should create a new pipeline with the additional stage""" - stage_1 = stages.GenericStage("first") + stage_1 = stages.RawStage("first") ppl_1 = _make_pipeline(stage_1, client=object()) - stage_2 = stages.GenericStage("second") + stage_2 = stages.RawStage("second") ppl_2 = ppl_1._append(stage_2) assert ppl_1 != ppl_2 assert len(ppl_1.stages) == 1 @@ -124,7 +124,7 @@ def test_pipeline_stream_empty(): client._database = "B" mock_rpc = client._firestore_api.execute_pipeline mock_rpc.return_value = [ExecutePipelineResponse()] - ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + ppl_1 = _make_pipeline(stages.RawStage("s"), client=client) results = list(ppl_1.stream()) assert results == [] @@ -151,7 +151,7 @@ def test_pipeline_stream_no_doc_ref(): mock_rpc.return_value = [ ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) ] - ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + ppl_1 = _make_pipeline(stages.RawStage("s"), client=client) results = list(ppl_1.stream()) assert len(results) == 1 @@ -378,8 +378,8 @@ def test_pipeline_execute_stream_equivalence_mocked(): ("unnest", ("field_name", "alias"), stages.Unnest), ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), - ("generic_stage", ("stage_name",), stages.GenericStage), - ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("raw_stage", ("stage_name",), stages.RawStage), + ("raw_stage", ("stage_name", Field.of("n")), stages.RawStage), ("offset", (1,), stages.Offset), ("limit", (1,), stages.Limit), ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 522b51c84..d3b7dfbf2 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -23,8 +23,8 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint import google.cloud.firestore_v1.pipeline_expressions as expr -from 
google.cloud.firestore_v1.pipeline_expressions import BooleanExpr -from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpression +from google.cloud.firestore_v1.pipeline_expressions import Expression from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.pipeline_expressions import Field from google.cloud.firestore_v1.pipeline_expressions import Ordering @@ -167,7 +167,7 @@ def test_equality(self, first, second, expected): class TestSelectable: """ - contains tests for each Expr class that derives from Selectable + contains tests for each Expression class that derives from Selectable """ def test_ctor(self): @@ -224,7 +224,7 @@ def test_to_map(self): assert result[0] == "field1" assert result[1] == Value(field_reference_value="field1") - class TestAliasedExpr: + class TestAliasedExpression: def test_repr(self): instance = Field.of("field1").as_("alias1") assert repr(instance) == "Field.of('field1').as_('alias1')" @@ -232,14 +232,14 @@ def test_repr(self): def test_ctor(self): arg = Field.of("field1") alias = "alias1" - instance = expr.AliasedExpr(arg, alias) + instance = expr.AliasedExpression(arg, alias) assert instance.expr == arg assert instance.alias == alias def test_to_pb(self): arg = Field.of("field1") alias = "alias1" - instance = expr.AliasedExpr(arg, alias) + instance = expr.AliasedExpression(arg, alias) result = instance._to_pb() assert result.map_value.fields.get("alias1") == arg._to_pb() @@ -249,35 +249,8 @@ def test_to_map(self): assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") - class TestAliasedAggregate: - def test_repr(self): - instance = Field.of("field1").maximum().as_("alias1") - assert repr(instance) == "Field.of('field1').maximum().as_('alias1')" - - def test_ctor(self): - arg = Expr.minimum("field1") - alias = "alias1" - instance = expr.AliasedAggregate(arg, alias) - assert instance.expr == arg - assert instance.alias == alias - def test_to_pb(self): - arg = Field.of("field1").average() - alias = "alias1" - instance = expr.AliasedAggregate(arg, alias) - result = instance._to_pb() - assert result.map_value.fields.get("alias1") == arg._to_pb() - - def test_to_map(self): - arg = Field.of("field1").count() - alias = "alias1" - instance = expr.AliasedAggregate(arg, alias) - result = instance._to_map() - assert result[0] == "alias1" - assert result[1] == arg._to_pb() - - -class TestBooleanExpr: +class TestBooleanExpression: def test__from_query_filter_pb_composite_filter_or(self, mock_client): """ test composite OR filters @@ -305,7 +278,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): composite_filter=composite_pb ) - result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks field1 = Field.of("field1") @@ -344,7 +317,7 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): composite_filter=composite_pb ) - result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks field1 = Field.of("field1") @@ -391,7 +364,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): composite_filter=outer_or_pb ) - result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = 
BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) field1 = Field.of("field1") field2 = Field.of("field2") @@ -422,23 +395,23 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ) with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, expected_expr_func", [ - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), + (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expression.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, - Expr.is_not_nan, + Expression.is_not_nan, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - Expr.is_null, + Expression.is_null, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - Expr.is_not_null, + Expression.is_not_null, ), ], ) @@ -455,7 +428,7 @@ def test__from_query_filter_pb_unary_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr_inst = Field.of(field_path) expected_condition = expected_expr_func(field_expr_inst) @@ -476,7 +449,7 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, value, expected_expr_func", @@ -484,48 +457,48 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, - Expr.less_than, + Expression.less_than, ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, - Expr.less_than_or_equal, + Expression.less_than_or_equal, ), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, - Expr.greater_than, + Expression.greater_than, ), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, - Expr.greater_than_or_equal, + Expression.greater_than_or_equal, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expr.equal), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expression.equal), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, - Expr.not_equal, + Expression.not_equal, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, - Expr.array_contains, + Expression.array_contains, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], - Expr.array_contains_any, + Expression.array_contains_any, ), ( query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], - Expr.equal_any, + Expression.equal_any, ), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], - Expr.not_equal_any, + Expression.not_equal_any, ), ], ) @@ -544,7 +517,7 @@ def test__from_query_filter_pb_field_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) - result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr = Field.of(field_path) # convert values into 
constants if isinstance(value, list): expected_value = [expr.Constant.of(v) for v in value] else: expected_value = expr.Constant.of(value) expected_condition = expected_expr_func(field_expr, expected_value) assert repr(result) == repr(expected_condition) @@ -571,7 +544,7 @@ def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): - BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpression._from_query_filter_pb(wrapped_filter_pb, mock_client) def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ @@ -579,7 +552,7 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ # Test with an unexpected protobuf type with pytest.raises(TypeError, match="Unexpected filter type"): - BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) + BooleanExpression._from_query_filter_pb(document_pb.Value(), mock_client) class TestArray: @@ -655,7 +628,7 @@ def test_equals(self): class TestExpressionMethods: """ - contains test methods for each Expr method + contains test methods for each Expression method """ @pytest.mark.parametrize( @@ -693,26 +666,26 @@ def test_equality(self, first, second, expected): assert (first == second) is expected def _make_arg(self, name="Mock"): - class MockExpr(Constant): + class MockExpression(Constant): def __repr__(self): return self.value - arg = MockExpr(name) + arg = MockExpression(name) return arg def test_expression_wrong_first_type(self): """The first argument should always be an expression or field name""" expected_message = "must be called on an Expression or a string representing a field. got <class 'int'>." with pytest.raises(TypeError) as e1: - Expr.logical_minimum(5, 1) + Expression.logical_minimum(5, 1) assert str(e1.value) == f"'logical_minimum' {expected_message}" with pytest.raises(TypeError) as e2: - Expr.sqrt(9) + Expression.sqrt(9) assert str(e2.value) == f"'sqrt' {expected_message}" def test_expression_w_string(self): """should be able to use string for first argument. 
Should be interpreted as Field""" - instance = Expr.logical_minimum("first", "second") + instance = Expression.logical_minimum("first", "second") assert isinstance(instance.params[0], Field) assert instance.params[0].path == "first" @@ -736,7 +709,7 @@ def test_or(self): def test_array_contains(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element") - instance = Expr.array_contains(arg1, arg2) + instance = Expression.array_contains(arg1, arg2) assert instance.name == "array_contains" assert instance.params == [arg1, arg2] assert repr(instance) == "ArrayField.array_contains(Element)" @@ -747,7 +720,7 @@ def test_array_contains_any(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = self._make_arg("Element2") - instance = Expr.array_contains_any(arg1, [arg2, arg3]) + instance = Expression.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 @@ -761,7 +734,7 @@ def test_array_contains_any(self): def test_exists(self): arg1 = self._make_arg("Field") - instance = Expr.exists(arg1) + instance = Expression.exists(arg1) assert instance.name == "exists" assert instance.params == [arg1] assert repr(instance) == "Field.exists()" @@ -771,7 +744,7 @@ def test_exists(self): def test_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.equal(arg1, arg2) + instance = Expression.equal(arg1, arg2) assert instance.name == "equal" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.equal(Right)" @@ -781,7 +754,7 @@ def test_equal(self): def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.greater_than_or_equal(arg1, arg2) + instance = Expression.greater_than_or_equal(arg1, arg2) assert instance.name == "greater_than_or_equal" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.greater_than_or_equal(Right)" @@ -791,7 +764,7 @@ def test_greater_than_or_equal(self): def test_greater_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.greater_than(arg1, arg2) + instance = Expression.greater_than(arg1, arg2) assert instance.name == "greater_than" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.greater_than(Right)" @@ -801,7 +774,7 @@ def test_greater_than(self): def test_less_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.less_than_or_equal(arg1, arg2) + instance = Expression.less_than_or_equal(arg1, arg2) assert instance.name == "less_than_or_equal" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.less_than_or_equal(Right)" @@ -811,7 +784,7 @@ def test_less_than_or_equal(self): def test_less_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.less_than(arg1, arg2) + instance = Expression.less_than(arg1, arg2) assert instance.name == "less_than" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.less_than(Right)" @@ -821,7 +794,7 @@ def test_less_than(self): def test_not_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.not_equal(arg1, arg2) + instance = Expression.not_equal(arg1, arg2) assert instance.name == "not_equal" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.not_equal(Right)" @@ -832,7 +805,7 @@ def test_equal_any(self): arg1 = self._make_arg("Field") 
arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = Expr.equal_any(arg1, [arg2, arg3]) + instance = Expression.equal_any(arg1, [arg2, arg3]) assert instance.name == "equal_any" assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 @@ -845,7 +818,7 @@ def test_not_equal_any(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = Expr.not_equal_any(arg1, [arg2, arg3]) + instance = Expression.not_equal_any(arg1, [arg2, arg3]) assert instance.name == "not_equal_any" assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 @@ -856,7 +829,7 @@ def test_not_equal_any(self): def test_is_absent(self): arg1 = self._make_arg("Field") - instance = Expr.is_absent(arg1) + instance = Expression.is_absent(arg1) assert instance.name == "is_absent" assert instance.params == [arg1] assert repr(instance) == "Field.is_absent()" @@ -865,17 +838,17 @@ def test_is_absent(self): def test_if_absent(self): arg1 = self._make_arg("Field") - arg2 = self._make_arg("ThenExpr") - instance = Expr.if_absent(arg1, arg2) + arg2 = self._make_arg("ThenExpression") + instance = Expression.if_absent(arg1, arg2) assert instance.name == "if_absent" assert instance.params == [arg1, arg2] - assert repr(instance) == "Field.if_absent(ThenExpr)" + assert repr(instance) == "Field.if_absent(ThenExpression)" infix_instance = arg1.if_absent(arg2) assert infix_instance == instance def test_is_nan(self): arg1 = self._make_arg("Value") - instance = Expr.is_nan(arg1) + instance = Expression.is_nan(arg1) assert instance.name == "is_nan" assert instance.params == [arg1] assert repr(instance) == "Value.is_nan()" @@ -884,7 +857,7 @@ def test_is_nan(self): def test_is_not_nan(self): arg1 = self._make_arg("Value") - instance = Expr.is_not_nan(arg1) + instance = Expression.is_not_nan(arg1) assert instance.name == "is_not_nan" assert instance.params == [arg1] assert repr(instance) == "Value.is_not_nan()" @@ -893,7 +866,7 @@ def test_is_not_nan(self): def test_is_null(self): arg1 = self._make_arg("Value") - instance = Expr.is_null(arg1) + instance = Expression.is_null(arg1) assert instance.name == "is_null" assert instance.params == [arg1] assert repr(instance) == "Value.is_null()" @@ -902,7 +875,7 @@ def test_is_null(self): def test_is_not_null(self): arg1 = self._make_arg("Value") - instance = Expr.is_not_null(arg1) + instance = Expression.is_not_null(arg1) assert instance.name == "is_not_null" assert instance.params == [arg1] assert repr(instance) == "Value.is_not_null()" @@ -911,7 +884,7 @@ def test_is_not_null(self): def test_is_error(self): arg1 = self._make_arg("Value") - instance = Expr.is_error(arg1) + instance = Expression.is_error(arg1) assert instance.name == "is_error" assert instance.params == [arg1] assert repr(instance) == "Value.is_error()" @@ -920,11 +893,11 @@ def test_is_error(self): def test_if_error(self): arg1 = self._make_arg("Value") - arg2 = self._make_arg("ThenExpr") - instance = Expr.if_error(arg1, arg2) + arg2 = self._make_arg("ThenExpression") + instance = Expression.if_error(arg1, arg2) assert instance.name == "if_error" assert instance.params == [arg1, arg2] - assert repr(instance) == "Value.if_error(ThenExpr)" + assert repr(instance) == "Value.if_error(ThenExpression)" infix_instance = arg1.if_error(arg2) assert infix_instance == instance @@ -939,7 +912,7 @@ def test_array_contains_all(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = 
self._make_arg("Element2") - instance = Expr.array_contains_all(arg1, [arg2, arg3]) + instance = Expression.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 @@ -952,71 +925,73 @@ def test_array_contains_all(self): assert infix_instance == instance def test_ends_with(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Postfix") - instance = Expr.ends_with(arg1, arg2) + instance = Expression.ends_with(arg1, arg2) assert instance.name == "ends_with" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.ends_with(Postfix)" + assert repr(instance) == "Expression.ends_with(Postfix)" infix_instance = arg1.ends_with(arg2) assert infix_instance == instance def test_conditional(self): arg1 = self._make_arg("Condition") - arg2 = self._make_arg("ThenExpr") - arg3 = self._make_arg("ElseExpr") + arg2 = self._make_arg("ThenExpression") + arg3 = self._make_arg("ElseExpression") instance = expr.Conditional(arg1, arg2, arg3) assert instance.name == "conditional" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "Conditional(Condition, ThenExpr, ElseExpr)" + assert ( + repr(instance) == "Conditional(Condition, ThenExpression, ElseExpression)" + ) def test_like(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Pattern") - instance = Expr.like(arg1, arg2) + instance = Expression.like(arg1, arg2) assert instance.name == "like" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.like(Pattern)" + assert repr(instance) == "Expression.like(Pattern)" infix_instance = arg1.like(arg2) assert infix_instance == instance def test_regex_contains(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Regex") - instance = Expr.regex_contains(arg1, arg2) + instance = Expression.regex_contains(arg1, arg2) assert instance.name == "regex_contains" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.regex_contains(Regex)" + assert repr(instance) == "Expression.regex_contains(Regex)" infix_instance = arg1.regex_contains(arg2) assert infix_instance == instance def test_regex_match(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Regex") - instance = Expr.regex_match(arg1, arg2) + instance = Expression.regex_match(arg1, arg2) assert instance.name == "regex_match" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.regex_match(Regex)" + assert repr(instance) == "Expression.regex_match(Regex)" infix_instance = arg1.regex_match(arg2) assert infix_instance == instance def test_starts_with(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Prefix") - instance = Expr.starts_with(arg1, arg2) + instance = Expression.starts_with(arg1, arg2) assert instance.name == "starts_with" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.starts_with(Prefix)" + assert repr(instance) == "Expression.starts_with(Prefix)" infix_instance = arg1.starts_with(arg2) assert infix_instance == instance def test_string_contains(self): - arg1 = self._make_arg("Expr") + arg1 = self._make_arg("Expression") arg2 = self._make_arg("Substring") - instance = Expr.string_contains(arg1, arg2) + instance = Expression.string_contains(arg1, arg2) assert instance.name == "string_contains" assert 
instance.params == [arg1, arg2] - assert repr(instance) == "Expr.string_contains(Substring)" + assert repr(instance) == "Expression.string_contains(Substring)" infix_instance = arg1.string_contains(arg2) assert infix_instance == instance @@ -1031,7 +1006,7 @@ def test_xor(self): def test_divide(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.divide(arg1, arg2) + instance = Expression.divide(arg1, arg2) assert instance.name == "divide" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.divide(Right)" @@ -1041,7 +1016,7 @@ def test_divide(self): def test_logical_maximum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.logical_maximum(arg1, arg2) + instance = Expression.logical_maximum(arg1, arg2) assert instance.name == "maximum" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.logical_maximum(Right)" @@ -1051,7 +1026,7 @@ def test_logical_maximum(self): def test_logical_minimum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.logical_minimum(arg1, arg2) + instance = Expression.logical_minimum(arg1, arg2) assert instance.name == "minimum" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.logical_minimum(Right)" @@ -1060,7 +1035,7 @@ def test_logical_minimum(self): def test_to_lower(self): arg1 = self._make_arg("Input") - instance = Expr.to_lower(arg1) + instance = Expression.to_lower(arg1) assert instance.name == "to_lower" assert instance.params == [arg1] assert repr(instance) == "Input.to_lower()" @@ -1069,7 +1044,7 @@ def test_to_lower(self): def test_to_upper(self): arg1 = self._make_arg("Input") - instance = Expr.to_upper(arg1) + instance = Expression.to_upper(arg1) assert instance.name == "to_upper" assert instance.params == [arg1] assert repr(instance) == "Input.to_upper()" @@ -1078,7 +1053,7 @@ def test_to_upper(self): def test_trim(self): arg1 = self._make_arg("Input") - instance = Expr.trim(arg1) + instance = Expression.trim(arg1) assert instance.name == "trim" assert instance.params == [arg1] assert repr(instance) == "Input.trim()" @@ -1087,7 +1062,7 @@ def test_trim(self): def test_string_reverse(self): arg1 = self._make_arg("Input") - instance = Expr.string_reverse(arg1) + instance = Expression.string_reverse(arg1) assert instance.name == "string_reverse" assert instance.params == [arg1] assert repr(instance) == "Input.string_reverse()" @@ -1097,7 +1072,7 @@ def test_string_reverse(self): def test_substring(self): arg1 = self._make_arg("Input") arg2 = self._make_arg("Position") - instance = Expr.substring(arg1, arg2) + instance = Expression.substring(arg1, arg2) assert instance.name == "substring" assert instance.params == [arg1, arg2] assert repr(instance) == "Input.substring(Position)" @@ -1108,7 +1083,7 @@ def test_substring_w_length(self): arg1 = self._make_arg("Input") arg2 = self._make_arg("Position") arg3 = self._make_arg("Length") - instance = Expr.substring(arg1, arg2, arg3) + instance = Expression.substring(arg1, arg2, arg3) assert instance.name == "substring" assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "Input.substring(Position, Length)" @@ -1118,7 +1093,7 @@ def test_substring_w_length(self): def test_join(self): arg1 = self._make_arg("Array") arg2 = self._make_arg("Separator") - instance = Expr.join(arg1, arg2) + instance = Expression.join(arg1, arg2) assert instance.name == "join" assert instance.params == [arg1, arg2] assert repr(instance) == "Array.join(Separator)" @@ 
-1128,7 +1103,7 @@ def test_join(self): def test_map_get(self): arg1 = self._make_arg("Map") arg2 = "key" - instance = Expr.map_get(arg1, arg2) + instance = Expression.map_get(arg1, arg2) assert instance.name == "map_get" assert instance.params == [arg1, Constant.of(arg2)] assert repr(instance) == "Map.map_get(Constant.of('key'))" @@ -1138,7 +1113,7 @@ def test_map_get(self): def test_map_remove(self): arg1 = self._make_arg("Map") arg2 = "key" - instance = Expr.map_remove(arg1, arg2) + instance = Expression.map_remove(arg1, arg2) assert instance.name == "map_remove" assert instance.params == [arg1, Constant.of(arg2)] assert repr(instance) == "Map.map_remove(Constant.of('key'))" @@ -1149,7 +1124,7 @@ def test_map_merge(self): arg1 = expr.Map({"a": 1}) arg2 = expr.Map({"b": 2}) arg3 = {"c": 3} - instance = Expr.map_merge(arg1, arg2, arg3) + instance = Expression.map_merge(arg1, arg2, arg3) assert instance.name == "map_merge" assert instance.params == [arg1, arg2, expr.Map(arg3)] assert repr(instance) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))" @@ -1159,7 +1134,7 @@ def test_map_merge(self): def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.mod(arg1, arg2) + instance = Expression.mod(arg1, arg2) assert instance.name == "mod" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.mod(Right)" @@ -1169,7 +1144,7 @@ def test_mod(self): def test_multiply(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.multiply(arg1, arg2) + instance = Expression.multiply(arg1, arg2) assert instance.name == "multiply" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.multiply(Right)" @@ -1180,7 +1155,7 @@ def test_string_concat(self): arg1 = self._make_arg("Str1") arg2 = self._make_arg("Str2") arg3 = self._make_arg("Str3") - instance = Expr.string_concat(arg1, arg2, arg3) + instance = Expression.string_concat(arg1, arg2, arg3) assert instance.name == "string_concat" assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "Str1.string_concat(Str2, Str3)" @@ -1190,7 +1165,7 @@ def test_string_concat(self): def test_subtract(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.subtract(arg1, arg2) + instance = Expression.subtract(arg1, arg2) assert instance.name == "subtract" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.subtract(Right)" @@ -1207,7 +1182,7 @@ def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = Expr.timestamp_add(arg1, arg2, arg3) + instance = Expression.timestamp_add(arg1, arg2, arg3) assert instance.name == "timestamp_add" assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "Timestamp.timestamp_add(Unit, Amount)" @@ -1218,7 +1193,7 @@ def test_timestamp_subtract(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = Expr.timestamp_subtract(arg1, arg2, arg3) + instance = Expression.timestamp_subtract(arg1, arg2, arg3) assert instance.name == "timestamp_subtract" assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "Timestamp.timestamp_subtract(Unit, Amount)" @@ -1227,7 +1202,7 @@ def test_timestamp_subtract(self): def test_timestamp_to_unix_micros(self): arg1 = self._make_arg("Input") - instance = Expr.timestamp_to_unix_micros(arg1) + instance = Expression.timestamp_to_unix_micros(arg1) assert instance.name == 
"timestamp_to_unix_micros" assert instance.params == [arg1] assert repr(instance) == "Input.timestamp_to_unix_micros()" @@ -1236,7 +1211,7 @@ def test_timestamp_to_unix_micros(self): def test_timestamp_to_unix_millis(self): arg1 = self._make_arg("Input") - instance = Expr.timestamp_to_unix_millis(arg1) + instance = Expression.timestamp_to_unix_millis(arg1) assert instance.name == "timestamp_to_unix_millis" assert instance.params == [arg1] assert repr(instance) == "Input.timestamp_to_unix_millis()" @@ -1245,7 +1220,7 @@ def test_timestamp_to_unix_millis(self): def test_timestamp_to_unix_seconds(self): arg1 = self._make_arg("Input") - instance = Expr.timestamp_to_unix_seconds(arg1) + instance = Expression.timestamp_to_unix_seconds(arg1) assert instance.name == "timestamp_to_unix_seconds" assert instance.params == [arg1] assert repr(instance) == "Input.timestamp_to_unix_seconds()" @@ -1254,7 +1229,7 @@ def test_timestamp_to_unix_seconds(self): def test_unix_micros_to_timestamp(self): arg1 = self._make_arg("Input") - instance = Expr.unix_micros_to_timestamp(arg1) + instance = Expression.unix_micros_to_timestamp(arg1) assert instance.name == "unix_micros_to_timestamp" assert instance.params == [arg1] assert repr(instance) == "Input.unix_micros_to_timestamp()" @@ -1263,7 +1238,7 @@ def test_unix_micros_to_timestamp(self): def test_unix_millis_to_timestamp(self): arg1 = self._make_arg("Input") - instance = Expr.unix_millis_to_timestamp(arg1) + instance = Expression.unix_millis_to_timestamp(arg1) assert instance.name == "unix_millis_to_timestamp" assert instance.params == [arg1] assert repr(instance) == "Input.unix_millis_to_timestamp()" @@ -1272,7 +1247,7 @@ def test_unix_millis_to_timestamp(self): def test_unix_seconds_to_timestamp(self): arg1 = self._make_arg("Input") - instance = Expr.unix_seconds_to_timestamp(arg1) + instance = Expression.unix_seconds_to_timestamp(arg1) assert instance.name == "unix_seconds_to_timestamp" assert instance.params == [arg1] assert repr(instance) == "Input.unix_seconds_to_timestamp()" @@ -1282,7 +1257,7 @@ def test_unix_seconds_to_timestamp(self): def test_euclidean_distance(self): arg1 = self._make_arg("Vector1") arg2 = self._make_arg("Vector2") - instance = Expr.euclidean_distance(arg1, arg2) + instance = Expression.euclidean_distance(arg1, arg2) assert instance.name == "euclidean_distance" assert instance.params == [arg1, arg2] assert repr(instance) == "Vector1.euclidean_distance(Vector2)" @@ -1292,7 +1267,7 @@ def test_euclidean_distance(self): def test_cosine_distance(self): arg1 = self._make_arg("Vector1") arg2 = self._make_arg("Vector2") - instance = Expr.cosine_distance(arg1, arg2) + instance = Expression.cosine_distance(arg1, arg2) assert instance.name == "cosine_distance" assert instance.params == [arg1, arg2] assert repr(instance) == "Vector1.cosine_distance(Vector2)" @@ -1302,7 +1277,7 @@ def test_cosine_distance(self): def test_dot_product(self): arg1 = self._make_arg("Vector1") arg2 = self._make_arg("Vector2") - instance = Expr.dot_product(arg1, arg2) + instance = Expression.dot_product(arg1, arg2) assert instance.name == "dot_product" assert instance.params == [arg1, arg2] assert repr(instance) == "Vector1.dot_product(Vector2)" @@ -1329,7 +1304,7 @@ def test_vector_ctor(self, method, input): def test_vector_length(self): arg1 = self._make_arg("Array") - instance = Expr.vector_length(arg1) + instance = Expression.vector_length(arg1) assert instance.name == "vector_length" assert instance.params == [arg1] assert repr(instance) == 
"Array.vector_length()" @@ -1339,7 +1314,7 @@ def test_vector_length(self): def test_add(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = Expr.add(arg1, arg2) + instance = Expression.add(arg1, arg2) assert instance.name == "add" assert instance.params == [arg1, arg2] assert repr(instance) == "Left.add(Right)" @@ -1348,7 +1323,7 @@ def test_add(self): def test_abs(self): arg1 = self._make_arg("Value") - instance = Expr.abs(arg1) + instance = Expression.abs(arg1) assert instance.name == "abs" assert instance.params == [arg1] assert repr(instance) == "Value.abs()" @@ -1357,7 +1332,7 @@ def test_abs(self): def test_ceil(self): arg1 = self._make_arg("Value") - instance = Expr.ceil(arg1) + instance = Expression.ceil(arg1) assert instance.name == "ceil" assert instance.params == [arg1] assert repr(instance) == "Value.ceil()" @@ -1366,7 +1341,7 @@ def test_ceil(self): def test_exp(self): arg1 = self._make_arg("Value") - instance = Expr.exp(arg1) + instance = Expression.exp(arg1) assert instance.name == "exp" assert instance.params == [arg1] assert repr(instance) == "Value.exp()" @@ -1375,7 +1350,7 @@ def test_exp(self): def test_floor(self): arg1 = self._make_arg("Value") - instance = Expr.floor(arg1) + instance = Expression.floor(arg1) assert instance.name == "floor" assert instance.params == [arg1] assert repr(instance) == "Value.floor()" @@ -1384,7 +1359,7 @@ def test_floor(self): def test_ln(self): arg1 = self._make_arg("Value") - instance = Expr.ln(arg1) + instance = Expression.ln(arg1) assert instance.name == "ln" assert instance.params == [arg1] assert repr(instance) == "Value.ln()" @@ -1394,7 +1369,7 @@ def test_ln(self): def test_log(self): arg1 = self._make_arg("Value") arg2 = self._make_arg("Base") - instance = Expr.log(arg1, arg2) + instance = Expression.log(arg1, arg2) assert instance.name == "log" assert instance.params == [arg1, arg2] assert repr(instance) == "Value.log(Base)" @@ -1403,7 +1378,7 @@ def test_log(self): def test_log10(self): arg1 = self._make_arg("Value") - instance = Expr.log10(arg1) + instance = Expression.log10(arg1) assert instance.name == "log10" assert instance.params == [arg1] assert repr(instance) == "Value.log10()" @@ -1413,7 +1388,7 @@ def test_log10(self): def test_pow(self): arg1 = self._make_arg("Value") arg2 = self._make_arg("Exponent") - instance = Expr.pow(arg1, arg2) + instance = Expression.pow(arg1, arg2) assert instance.name == "pow" assert instance.params == [arg1, arg2] assert repr(instance) == "Value.pow(Exponent)" @@ -1422,7 +1397,7 @@ def test_pow(self): def test_round(self): arg1 = self._make_arg("Value") - instance = Expr.round(arg1) + instance = Expression.round(arg1) assert instance.name == "round" assert instance.params == [arg1] assert repr(instance) == "Value.round()" @@ -1431,7 +1406,7 @@ def test_round(self): def test_sqrt(self): arg1 = self._make_arg("Value") - instance = Expr.sqrt(arg1) + instance = Expression.sqrt(arg1) assert instance.name == "sqrt" assert instance.params == [arg1] assert repr(instance) == "Value.sqrt()" @@ -1440,7 +1415,7 @@ def test_sqrt(self): def test_array_length(self): arg1 = self._make_arg("Array") - instance = Expr.array_length(arg1) + instance = Expression.array_length(arg1) assert instance.name == "array_length" assert instance.params == [arg1] assert repr(instance) == "Array.array_length()" @@ -1449,7 +1424,7 @@ def test_array_length(self): def test_array_reverse(self): arg1 = self._make_arg("Array") - instance = Expr.array_reverse(arg1) + instance = 
Expression.array_reverse(arg1) assert instance.name == "array_reverse" assert instance.params == [arg1] assert repr(instance) == "Array.array_reverse()" @@ -1459,7 +1434,7 @@ def test_array_reverse(self): def test_array_concat(self): arg1 = self._make_arg("ArrayRef1") arg2 = self._make_arg("ArrayRef2") - instance = Expr.array_concat(arg1, arg2) + instance = Expression.array_concat(arg1, arg2) assert instance.name == "array_concat" assert instance.params == [arg1, arg2] assert repr(instance) == "ArrayRef1.array_concat(ArrayRef2)" @@ -1480,20 +1455,20 @@ def test_array_concat_multiple(self): ) def test_byte_length(self): - arg1 = self._make_arg("Expr") - instance = Expr.byte_length(arg1) + arg1 = self._make_arg("Expression") + instance = Expression.byte_length(arg1) assert instance.name == "byte_length" assert instance.params == [arg1] - assert repr(instance) == "Expr.byte_length()" + assert repr(instance) == "Expression.byte_length()" infix_instance = arg1.byte_length() assert infix_instance == instance def test_char_length(self): - arg1 = self._make_arg("Expr") - instance = Expr.char_length(arg1) + arg1 = self._make_arg("Expression") + instance = Expression.char_length(arg1) assert instance.name == "char_length" assert instance.params == [arg1] - assert repr(instance) == "Expr.char_length()" + assert repr(instance) == "Expression.char_length()" infix_instance = arg1.char_length() assert infix_instance == instance @@ -1501,7 +1476,7 @@ def test_concat(self): arg1 = self._make_arg("First") arg2 = self._make_arg("Second") arg3 = "Third" - instance = Expr.concat(arg1, arg2, arg3) + instance = Expression.concat(arg1, arg2, arg3) assert instance.name == "concat" assert instance.params == [arg1, arg2, Constant.of(arg3)] assert repr(instance) == "First.concat(Second, Constant.of('Third'))" @@ -1509,17 +1484,17 @@ def test_concat(self): assert infix_instance == instance def test_length(self): - arg1 = self._make_arg("Expr") - instance = Expr.length(arg1) + arg1 = self._make_arg("Expression") + instance = Expression.length(arg1) assert instance.name == "length" assert instance.params == [arg1] - assert repr(instance) == "Expr.length()" + assert repr(instance) == "Expression.length()" infix_instance = arg1.length() assert infix_instance == instance def test_collection_id(self): arg1 = self._make_arg("Value") - instance = Expr.collection_id(arg1) + instance = Expression.collection_id(arg1) assert instance.name == "collection_id" assert instance.params == [arg1] assert repr(instance) == "Value.collection_id()" @@ -1528,7 +1503,7 @@ def test_collection_id(self): def test_document_id(self): arg1 = self._make_arg("Value") - instance = Expr.document_id(arg1) + instance = Expression.document_id(arg1) assert instance.name == "document_id" assert instance.params == [arg1] assert repr(instance) == "Value.document_id()" @@ -1537,7 +1512,7 @@ def test_document_id(self): def test_sum(self): arg1 = self._make_arg("Value") - instance = Expr.sum(arg1) + instance = Expression.sum(arg1) assert instance.name == "sum" assert instance.params == [arg1] assert repr(instance) == "Value.sum()" @@ -1546,7 +1521,7 @@ def test_sum(self): def test_average(self): arg1 = self._make_arg("Value") - instance = Expr.average(arg1) + instance = Expression.average(arg1) assert instance.name == "average" assert instance.params == [arg1] assert repr(instance) == "Value.average()" @@ -1555,7 +1530,7 @@ def test_average(self): def test_count(self): arg1 = self._make_arg("Value") - instance = Expr.count(arg1) + instance = 
Expression.count(arg1) assert instance.name == "count" assert instance.params == [arg1] assert repr(instance) == "Value.count()" @@ -1570,7 +1545,7 @@ def test_base_count(self): def test_count_if(self): arg1 = self._make_arg("Value") - instance = Expr.count_if(arg1) + instance = Expression.count_if(arg1) assert instance.name == "count_if" assert instance.params == [arg1] assert repr(instance) == "Value.count_if()" @@ -1579,7 +1554,7 @@ def test_count_if(self): def test_count_distinct(self): arg1 = self._make_arg("Value") - instance = Expr.count_distinct(arg1) + instance = Expression.count_distinct(arg1) assert instance.name == "count_distinct" assert instance.params == [arg1] assert repr(instance) == "Value.count_distinct()" @@ -1588,7 +1563,7 @@ def test_count_distinct(self): def test_minimum(self): arg1 = self._make_arg("Value") - instance = Expr.minimum(arg1) + instance = Expression.minimum(arg1) assert instance.name == "minimum" assert instance.params == [arg1] assert repr(instance) == "Value.minimum()" @@ -1597,7 +1572,7 @@ def test_minimum(self): def test_maximum(self): arg1 = self._make_arg("Value") - instance = Expr.maximum(arg1) + instance = Expression.maximum(arg1) assert instance.name == "maximum" assert instance.params == [arg1] assert repr(instance) == "Value.maximum()" diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index bed1bd05a..e29b763e2 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -17,7 +17,7 @@ from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient -from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_document import BaseDocumentReference diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 1d2ff8760..18c9d6790 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -16,7 +16,7 @@ from unittest import mock from google.cloud.firestore_v1.base_pipeline import _BasePipeline -import google.cloud.firestore_v1._pipeline_stages as stages +import google.cloud.firestore_v1.pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import ( Constant, Field, @@ -420,9 +420,9 @@ def test_to_pb_no_options(self): assert len(result.args) == 3 -class TestGenericStage: +class TestRawStage: def _make_one(self, *args, **kwargs): - return stages.GenericStage(*args, **kwargs) + return stages.RawStage(*args, **kwargs) @pytest.mark.parametrize( "input_args,expected_params", @@ -471,7 +471,7 @@ def test_ctor_with_options(self): standard_unnest = stages.Unnest( field, alias, options=stages.UnnestOptions(**options) ) - generic_unnest = stages.GenericStage("unnest", field, alias, options=options) + generic_unnest = stages.RawStage("unnest", field, alias, options=options) assert standard_unnest._pb_args() == generic_unnest._pb_args() assert standard_unnest._pb_options() == generic_unnest._pb_options() assert standard_unnest._to_pb() == generic_unnest._to_pb() @@ -479,8 +479,8 @@ def test_ctor_with_options(self): @pytest.mark.parametrize( "input_args,expected", [ - (("name",), "GenericStage(name='name')"), - (("custom", Value(string_value="val")), "GenericStage(name='custom')"), + (("name",), "RawStage(name='name')"), + (("custom", Value(string_value="val")), 
"RawStage(name='custom')"), ], ) def test_repr(self, input_args, expected): From 9a35dfe88f4863cbe17bfe0492901ef3eed18189 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 31 Oct 2025 11:56:42 -0700 Subject: [PATCH 09/27] feat: replace_with pipeline stage (#1121) --- google/cloud/firestore_v1/base_pipeline.py | 41 ++++++++++++++++++++ google/cloud/firestore_v1/pipeline_stages.py | 11 ++++++ tests/system/pipeline_e2e/general.yaml | 31 ++++++++++++++- tests/unit/v1/test_pipeline.py | 2 + tests/unit/v1/test_pipeline_expressions.py | 18 ++++----- tests/unit/v1/test_pipeline_stages.py | 33 ++++++++++++++++ 6 files changed, 126 insertions(+), 10 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 63fee19fa..c66321793 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -315,6 +315,47 @@ def find_nearest( stages.FindNearest(field, vector, distance_measure, options) ) + def replace_with( + self, + field: Selectable, + ) -> "_BasePipeline": + """ + Fully overwrites all fields in a document with those coming from a nested map. + + This stage allows you to emit a map value as a document. Each key of the map becomes a field + on the document that contains the corresponding value. + + Example: + Input document: + ```json + { + "name": "John Doe Jr.", + "parents": { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + } + ``` + + >>> # Emit the 'parents' map as the document + >>> pipeline = client.pipeline().collection("people").replace_with(Field.of("parents")) + + Output document: + ```json + { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + ``` + + Args: + field: The `Selectable` field containing the map whose content will + replace the document. + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.ReplaceWith(field)) + def sort(self, *orders: stages.Ordering) -> "_BasePipeline": """ Sorts the documents from previous stages based on one or more `Ordering` criteria. 
diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 95ce32021..37829465e 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -362,6 +362,17 @@ def _pb_args(self) -> list[Value]: return [f._to_pb() for f in self.fields] +class ReplaceWith(Stage): + """Replaces the document content with the value of a specified field.""" + + def __init__(self, field: Selectable): + super().__init__("replace_with") + self.field = Field(field) if isinstance(field, str) else field + + def _pb_args(self): + return [self.field._to_pb(), Value(string_value="full_replace")] + + class Sample(Stage): """Performs pseudo-random sampling of documents.""" diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml index 23e98cf3d..8ff3f60d2 100644 --- a/tests/system/pipeline_e2e/general.yaml +++ b/tests/system/pipeline_e2e/general.yaml @@ -655,4 +655,33 @@ tests: fieldReferenceValue: tags_alias index: fieldReferenceValue: index - name: select \ No newline at end of file + name: select + - description: replaceWith + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - ReplaceWith: + - Field: awards + assert_results: + - hugo: True + nebula: False + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace_with \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b6d353f1a..e203f6d69 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -369,6 +369,8 @@ def test_pipeline_execute_stream_equivalence_mocked(): ("name", [0.1], "cosine", stages.FindNearestOptions(10)), stages.FindNearest, ), + ("replace_with", ("name",), stages.ReplaceWith), + ("replace_with", (Field.of("n"),), stages.ReplaceWith), ("sort", (Field.of("n").descending(),), stages.Sort), ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), ("sample", (10,), stages.Sample), diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index d3b7dfbf2..ec7f4901e 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -555,6 +555,14 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): BooleanExpression._from_query_filter_pb(document_pb.Value(), mock_client) +class TestFunction: + def test_equals(self): + assert expr.Function.sqrt("1") == expr.Function.sqrt("1") + assert expr.Function.sqrt("1") != expr.Function.sqrt("2") + assert expr.Function.sqrt("1") != expr.Function.sum("1") + assert expr.Function.sqrt("1") != object() + + class TestArray: """Tests for the array class""" @@ -618,15 +626,7 @@ def test_w_exprs(self): ) -class TestFunction: - def test_equals(self): - assert expr.Function.sqrt("1") == expr.Function.sqrt("1") - assert expr.Function.sqrt("1") != expr.Function.sqrt("2") - assert expr.Function.sqrt("1") != expr.Function.sum("1") - assert expr.Function.sqrt("1") != object() - - class TestExpressionMethods: """ contains test methods for each Expression method """ diff --git
a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 18c9d6790..a2d466f47 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -562,6 +562,39 @@ def test_to_pb(self): assert len(result.options) == 0 +class TestReplaceWith: + def _make_one(self, *args, **kwargs): + return stages.ReplaceWith(*args, **kwargs) + + @pytest.mark.parametrize( + "in_field,expected_field", + [ + ("test", Field.of("test")), + ("test", Field.of("test")), + ("test", Field.of("test")), + (Field.of("test"), Field.of("test")), + (Field.of("test"), Field.of("test")), + ], + ) + def test_ctor(self, in_field, expected_field): + instance = self._make_one(in_field) + assert instance.field == expected_field + assert instance.name == "replace_with" + + def test_repr(self): + instance = self._make_one("test") + repr_str = repr(instance) + assert repr_str == "ReplaceWith(field=Field.of('test'))" + + def test_to_pb(self): + instance = self._make_one(Field.of("test")) + result = instance._to_pb() + assert result.name == "replace_with" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "test" + assert result.args[1].string_value == "full_replace" + + class TestSample: class TestSampleOptions: def test_ctor_percent(self): From 8e83a40de0f77042d8694859e653270046ccbebc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 6 Nov 2025 09:17:58 -0800 Subject: [PATCH 10/27] chore: remove is_nan and is_null (#1123) --- .../firestore_v1/pipeline_expressions.py | 63 ++------------- tests/system/pipeline_e2e/array.yaml | 6 +- tests/system/pipeline_e2e/data.yaml | 7 +- tests/system/pipeline_e2e/logical.yaml | 81 +++++-------------- tests/system/test_pipeline_acceptance.py | 4 + tests/system/test_system.py | 4 +- tests/system/test_system_async.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 59 ++++---------- 8 files changed, 61 insertions(+), 165 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 780fe8e8e..7a90dea1d 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -816,56 +816,6 @@ def if_absent(self, default_value: Expression | CONSTANT_TYPE) -> "Expression": [self, self._cast_to_expr_or_convert_to_constant(default_value)], ) - @expose_as_static - def is_nan(self) -> "BooleanExpression": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - - Example: - >>> # Check if the result of a calculation is NaN - >>> Field.of("value").divide(0).is_nan() - - Returns: - A new `Expression` representing the 'isNaN' check. - """ - return BooleanExpression("is_nan", [self]) - - @expose_as_static - def is_not_nan(self) -> "BooleanExpression": - """Creates an expression that checks if this expression evaluates to a non-'NaN' (Not a Number) value. - - Example: - >>> # Check if the result of a calculation is not NaN - >>> Field.of("value").divide(1).is_not_nan() - - Returns: - A new `Expression` representing the 'is not NaN' check. - """ - return BooleanExpression("is_not_nan", [self]) - - @expose_as_static - def is_null(self) -> "BooleanExpression": - """Creates an expression that checks if the value of a field is 'Null'. - - Example: - >>> Field.of("value").is_null() - - Returns: - A new `Expression` representing the 'isNull' check. 
- """ - return BooleanExpression("is_null", [self]) - - @expose_as_static - def is_not_null(self) -> "BooleanExpression": - """Creates an expression that checks if the value of a field is not 'Null'. - - Example: - >>> Field.of("value").is_not_null() - - Returns: - A new `Expression` representing the 'isNotNull' check. - """ - return BooleanExpression("is_not_null", [self]) - @expose_as_static def is_error(self): """Creates an expression that checks if a given expression produces an error @@ -1653,7 +1603,10 @@ def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: return Constant(value) def __repr__(self): - return f"Constant.of({self.value!r})" + value_str = repr(self.value) + if isinstance(self.value, float) and value_str == "nan": + value_str = "math.nan" + return f"Constant.of({value_str})" def __hash__(self): return hash(self.value) @@ -1827,13 +1780,13 @@ def _from_query_filter_pb(filter_pb, client): elif isinstance(filter_pb, Query_pb.UnaryFilter): field = Field.of(filter_pb.field.field_path) if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: - return And(field.exists(), field.is_nan()) + return And(field.exists(), field.equal(float("nan"))) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: - return And(field.exists(), field.is_not_nan()) + return And(field.exists(), Not(field.equal(float("nan")))) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.is_null()) + return And(field.exists(), field.equal(None)) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), field.is_not_null()) + return And(field.exists(), Not(field.equal(None))) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml index 3da16264d..d32491d8b 100644 --- a/tests/system/pipeline_e2e/array.yaml +++ b/tests/system/pipeline_e2e/array.yaml @@ -185,7 +185,7 @@ tests: - AliasedExpression: - Function.array_concat: - Field: tags - - Constant: ["new_tag", "another_tag"] + - ["new_tag", "another_tag"] - "concatenatedTags" assert_results: - concatenatedTags: @@ -232,8 +232,8 @@ tests: - AliasedExpression: - Function.array_concat: - Field: tags - - Constant: ["sci-fi"] - - Constant: ["classic", "epic"] + - ["sci-fi"] + - ["classic", "epic"] - "concatenatedTags" assert_results: - concatenatedTags: diff --git a/tests/system/pipeline_e2e/data.yaml b/tests/system/pipeline_e2e/data.yaml index 902f7782d..f2533d2b1 100644 --- a/tests/system/pipeline_e2e/data.yaml +++ b/tests/system/pipeline_e2e/data.yaml @@ -139,4 +139,9 @@ data: vec3: embedding: [5.0, 6.0, 7.0] vec4: - embedding: [1.0, 2.0, 4.0] \ No newline at end of file + embedding: [1.0, 2.0, 4.0] + errors: + doc_with_nan: + value: "NaN" + doc_with_null: + value: null \ No newline at end of file diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml index efc91c29a..0bc889ac1 100644 --- a/tests/system/pipeline_e2e/logical.yaml +++ b/tests/system/pipeline_e2e/logical.yaml @@ -238,89 +238,48 @@ tests: expression: fieldReferenceValue: title name: sort - - description: testChecks + - description: testIsNull pipeline: - - Collection: books + - Collection: errors - Where: - - Not: - - Function.is_nan: - - Field: rating - - Select: - - AliasedExpression: - - Not: - - Function.is_nan: - - Field: rating - - "ratingIsNotNaN" - - Limit: 1 + - Function.equal: + - Field: value + - null 
assert_results: - - ratingIsNotNaN: true + - value: null assert_proto: pipeline: stages: - args: - - referenceValue: /books + - referenceValue: /errors name: collection - args: - functionValue: args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: where - - args: - - mapValue: - fields: - ratingIsNotNaN: - functionValue: - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_nan - name: not - name: select - - args: - - integerValue: '1' - name: limit - - description: testIsNotNull - pipeline: - - Collection: books - - Where: - - Function.is_not_null: - - Field: rating - assert_count: 10 - assert_proto: - pipeline: - stages: - - args: - - referenceValue: /books - name: collection - - args: - - functionValue: - args: - - fieldReferenceValue: rating - name: is_not_null + - fieldReferenceValue: value + - nullValue: null + name: equal name: where - - description: testIsNotNaN + - description: testIsNan pipeline: - - Collection: books + - Collection: errors - Where: - - Function.is_not_nan: - - Field: rating - assert_count: 10 + - Function.equal: + - Field: value + - NaN + assert_count: 1 assert_proto: pipeline: stages: - args: - - referenceValue: /books + - referenceValue: /errors name: collection - args: - functionValue: args: - - fieldReferenceValue: rating - name: is_not_nan + - fieldReferenceValue: value + - doubleValue: NaN + name: equal name: where - description: testIsAbsent pipeline: diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index fd74d9458..02a27ca86 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -270,6 +270,8 @@ def _parse_expressions(client, yaml_element: Any): } elif _is_expr_string(yaml_element): return getattr(pipeline_expressions, yaml_element)() + elif yaml_element == "NaN": + return float(yaml_element) else: return yaml_element @@ -351,6 +353,8 @@ def _parse_yaml_types(data): return parsed_datetime except ValueError: pass + if data == "NaN": + return float("NaN") return data diff --git a/tests/system/test_system.py b/tests/system/test_system.py index ce4f64b32..592a73f67 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -1727,7 +1727,7 @@ def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -3287,7 +3287,7 @@ def test_query_with_or_composite_filter(collection, database): verify_pipeline(query) -@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index ed679402a..f87da0112 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -1651,7 +1651,7 @@ async def test_query_with_order_dot_key(client, cleanup, database): assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index ec7f4901e..1546dbe66 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -14,6 +14,7 @@ import pytest import mock +import math import datetime from google.cloud.firestore_v1 import _helpers @@ -117,6 +118,13 @@ def test_to_pb(self, input_val, to_pb_val): instance = Constant.of(input_val) assert instance._to_pb() == to_pb_val + @pytest.mark.parametrize("input", [float("nan"), math.nan]) + def test_nan_to_pb(self, input): + instance = Constant.of(input) + assert repr(instance) == "Constant.of(math.nan)" + pb_val = instance._to_pb() + assert math.isnan(pb_val.double_value) + @pytest.mark.parametrize( "input_val,expected", [ @@ -284,7 +292,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): field1 = Field.of("field1") field2 = Field.of("field2") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) - expected_cond2 = expr.And(field2.exists(), field2.is_null()) + expected_cond2 = expr.And(field2.exists(), field2.equal(None)) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -371,7 +379,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): field3 = Field.of("field3") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) - expected_cond3 = expr.And(field3.exists(), field3.is_not_null()) + expected_cond3 = expr.And(field3.exists(), expr.Not(field3.equal(None))) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -400,18 +408,21 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): @pytest.mark.parametrize( "op_enum, expected_expr_func", [ - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expression.is_nan), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, + lambda x: x.equal(float("nan")), + ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, - Expression.is_not_nan, + lambda x: expr.Not(x.equal(float("nan"))), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - Expression.is_null, + lambda x: x.equal(None), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - Expression.is_not_null, + lambda x: expr.Not(x.equal(None)), ), ], ) @@ -846,42 +857,6 @@ def test_if_absent(self): infix_instance = arg1.if_absent(arg2) assert infix_instance == instance - def test_is_nan(self): - arg1 = self._make_arg("Value") - instance = Expression.is_nan(arg1) - assert instance.name == "is_nan" - assert instance.params == [arg1] - assert repr(instance) == "Value.is_nan()" - infix_instance = arg1.is_nan() - assert infix_instance == instance - - def test_is_not_nan(self): - arg1 = self._make_arg("Value") - instance = Expression.is_not_nan(arg1) - assert instance.name == "is_not_nan" - assert instance.params == [arg1] - assert repr(instance) == "Value.is_not_nan()" - infix_instance = arg1.is_not_nan() - assert infix_instance == instance - - def test_is_null(self): - arg1 = self._make_arg("Value") - instance = Expression.is_null(arg1) - assert instance.name == "is_null" - assert 
instance.params == [arg1] - assert repr(instance) == "Value.is_null()" - infix_instance = arg1.is_null() - assert infix_instance == instance - - def test_is_not_null(self): - arg1 = self._make_arg("Value") - instance = Expression.is_not_null(arg1) - assert instance.name == "is_not_null" - assert instance.params == [arg1] - assert repr(instance) == "Value.is_not_null()" - infix_instance = arg1.is_not_null() - assert infix_instance == instance - def test_is_error(self): arg1 = self._make_arg("Value") instance = Expression.is_error(arg1) From b77002d69564683ef454a29f4bcae4d98f0f06fc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 7 Nov 2025 12:13:33 -0800 Subject: [PATCH 11/27] feat: pipelines create_from() (#1124) --- google/cloud/firestore_v1/base_aggregation.py | 8 +++-- google/cloud/firestore_v1/base_collection.py | 7 ++-- google/cloud/firestore_v1/base_query.py | 12 +++---- google/cloud/firestore_v1/pipeline_source.py | 20 ++++++++++++ tests/system/test_system.py | 3 +- tests/system/test_system_async.py | 3 +- tests/unit/v1/test_aggregation.py | 10 +++--- tests/unit/v1/test_async_aggregation.py | 10 +++--- tests/unit/v1/test_async_collection.py | 8 +---- tests/unit/v1/test_async_query.py | 4 +-- tests/unit/v1/test_base_collection.py | 7 ++-- tests/unit/v1/test_base_query.py | 32 +++++++------------ tests/unit/v1/test_collection.py | 9 +----- tests/unit/v1/test_pipeline_source.py | 26 +++++++++++++++ tests/unit/v1/test_query.py | 4 +-- 15 files changed, 98 insertions(+), 65 deletions(-) diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index d8d7cc6b4..6f392207e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -48,6 +48,7 @@ from google.cloud.firestore_v1.stream_generator import ( StreamGenerator, ) + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -356,14 +357,15 @@ def stream( A generator of the query results. 
""" - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off of Raises: - - ValueError: raised if Query wasn't created with an associated client - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query @@ -371,4 +373,4 @@ def pipeline(self): # use autoindexer to keep track of which field number to use for un-aliased fields autoindexer = itertools.count(start=1) exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] - return self._nested_query.pipeline().aggregate(*exprs) + return self._nested_query._build_pipeline(source).aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index a4cc2b1b7..567fe4d8a 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -48,6 +48,7 @@ from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator @@ -602,18 +603,20 @@ def find_nearest( distance_threshold=distance_threshold, ) - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off o Raises: - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ - return self._query().pipeline() + return self._query()._build_pipeline(source) def _auto_id() -> str: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 0f4347e5f..b1b74fcf1 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -67,6 +67,7 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -1129,24 +1130,23 @@ def recursive(self: QueryType) -> QueryType: return copied - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off of Raises: - - ValueError: raised if Query wasn't created with an associated client - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ - if not self._client: - raise ValueError("Query does not have an associated client") if self._all_descendants: - ppl = self._client.pipeline().collection_group(self._parent.id) + ppl = source.collection_group(self._parent.id) else: - ppl = self._client.pipeline().collection(self._parent._path) + ppl = 
source.collection(self._parent._path) # Filters for filter_ in self._field_filters: diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index f4328afa4..3fb73b365 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -22,6 +22,9 @@ from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + from google.cloud.firestore_v1.base_collection import BaseCollectionReference PipelineType = TypeVar("PipelineType", bound=_BasePipeline) @@ -43,6 +46,23 @@ def __init__(self, client: Client | AsyncClient): def _create_pipeline(self, source_stage): return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + def create_from( + self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference" + ) -> PipelineType: + """ + Create a pipeline from an existing query + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + query: the query to build the pipeline off of + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a new pipeline instance representing the query + """ + return query._build_pipeline(self) + def collection(self, path: str | tuple[str]) -> PipelineType: """ Creates a new Pipeline that operates on a specified Firestore collection. diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 592a73f67..09dc607eb 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -133,7 +133,8 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - pipeline = query.pipeline() + client = query._client + pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query with pytest.raises(query_exception.__class__): diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index f87da0112..5f8e07eda 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -213,7 +213,8 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - pipeline = query.pipeline() + client = query._client + pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query with pytest.raises(query_exception.__class__): diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 299283564..96928e88e 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1040,7 +1040,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.sum(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1071,7 +1071,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): query = make_query(parent) aggregation_query = 
make_aggregation_query(query) aggregation_query.avg(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1102,7 +1102,7 @@ def test_aggreation_to_pipeline_count(in_alias, out_alias): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.count(alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1127,7 +1127,7 @@ def test_aggreation_to_pipeline_count_increment(): aggregation_query = make_aggregation_query(query) for _ in range(n): aggregation_query.count() - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) aggregate_stage = pipeline.stages[1] assert len(aggregate_stage.accumulators) == n for i in range(n): @@ -1146,7 +1146,7 @@ def test_aggreation_to_pipeline_complex(): aggregation_query.count() aggregation_query.avg("other") aggregation_query.sum("final") - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 3 assert isinstance(pipeline.stages[0], Collection) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index eca2ecef1..025146145 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -716,7 +716,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.sum(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -747,7 +747,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.avg(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -778,7 +778,7 @@ def test_async_aggreation_to_pipeline_count(in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -803,7 +803,7 @@ def test_aggreation_to_pipeline_count_increment(): aggregation_query = make_async_aggregation_query(query) for _ in range(n): aggregation_query.count() - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) aggregate_stage = pipeline.stages[1] assert len(aggregate_stage.accumulators) == n for i in range(n): @@ -822,7 +822,7 @@ def test_async_aggreation_to_pipeline_complex(): 
aggregation_query.count() aggregation_query.avg("other") aggregation_query.sum("final") - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 3 assert isinstance(pipeline.stages[0], Collection) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 5b4df059a..34a259996 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -609,15 +609,9 @@ def test_asynccollectionreference_pipeline(): client = make_async_client() collection = _make_async_collection_reference("collection", client=client) - pipeline = collection.pipeline() + pipeline = collection._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) # should have single "Collection" stage assert len(pipeline.stages) == 1 assert isinstance(pipeline.stages[0], Collection) assert pipeline.stages[0].path == "/collection" - - -def test_asynccollectionreference_pipeline_no_client(): - collection = _make_async_collection_reference("collection") - with pytest.raises(ValueError, match="client"): - collection.pipeline() diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index dc5eb9e8a..6e2aa8393 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -917,7 +917,7 @@ def test_asyncquery_collection_pipeline_type(): client = make_async_client() parent = client.collection("test") query = parent._query() - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, AsyncPipeline) @@ -926,5 +926,5 @@ def test_asyncquery_collectiongroup_pipeline_type(): client = make_async_client() query = client.collection_group("test") - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, AsyncPipeline) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 7f7be9c07..9124e4d01 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -430,10 +430,11 @@ def test_basecollectionreference_pipeline(mock_query): _query.return_value = mock_query collection = _make_base_collection_reference("collection") - pipeline = collection.pipeline() + mock_source = mock.Mock() + pipeline = collection._build_pipeline(mock_source) - mock_query.pipeline.assert_called_once_with() - assert pipeline == mock_query.pipeline.return_value + mock_query._build_pipeline.assert_called_once_with(mock_source) + assert pipeline == mock_query._build_pipeline.return_value @mock.patch("random.choice") diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 925010070..4a4dac727 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1994,18 +1994,10 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time -def test__query_pipeline_no_client(): - mock_parent = mock.Mock() - mock_parent._client = None - query = _make_base_query(mock_parent) - with pytest.raises(ValueError, match="client"): - query.pipeline() - - def test__query_pipeline_decendants(): client = make_client() query = client.collection_group("my_col") - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 1 stage = pipeline.stages[0] @@ -2025,7 +2017,7 @@ def 
test__query_pipeline_no_decendants(in_path, out_path): client = make_client() collection = client.collection(in_path) query = collection._query() - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 1 stage = pipeline.stages[0] @@ -2043,7 +2035,7 @@ def test__query_pipeline_composite_filter(): with mock.patch.object( expr.BooleanExpression, "_from_query_filter_pb" ) as convert_mock: - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) convert_mock.assert_called_once_with(in_filter._to_pb(), client) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2054,7 +2046,7 @@ def test__query_pipeline_composite_filter(): def test__query_pipeline_projections(): client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2069,7 +2061,7 @@ def test__query_pipeline_order_exists_multiple(): client = make_client() query = client.collection("my_col").order_by("field_a").order_by("field_b") - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) # should have collection, where, and sort # we're interested in where @@ -2089,7 +2081,7 @@ def test__query_pipeline_order_exists_multiple(): def test__query_pipeline_order_exists_single(): client = make_client() query_single = client.collection("my_col").order_by("field_c") - pipeline_single = query_single.pipeline() + pipeline_single = query_single._build_pipeline(client.pipeline()) # should have collection, where, and sort # we're interested in where @@ -2110,7 +2102,7 @@ def test__query_pipeline_order_sorts(): .order_by("field_a", direction=BaseQuery.ASCENDING) .order_by("field_b", direction=BaseQuery.DESCENDING) ) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 3 sort_stage = pipeline.stages[2] @@ -2128,21 +2120,21 @@ def test__query_pipeline_unsupported(): client = make_client() query_start = client.collection("my_col").start_at({"field_a": "value"}) with pytest.raises(NotImplementedError, match="cursors"): - query_start.pipeline() + query_start._build_pipeline(client.pipeline()) query_end = client.collection("my_col").end_at({"field_a": "value"}) with pytest.raises(NotImplementedError, match="cursors"): - query_end.pipeline() + query_end._build_pipeline(client.pipeline()) query_limit_last = client.collection("my_col").limit_to_last(10) with pytest.raises(NotImplementedError, match="limit_to_last"): - query_limit_last.pipeline() + query_limit_last._build_pipeline(client.pipeline()) def test__query_pipeline_limit(): client = make_client() query = client.collection("my_col").limit(15) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2153,7 +2145,7 @@ def test__query_pipeline_limit(): def test__query_pipeline_offset(): client = make_client() query = client.collection("my_col").offset(5) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 76418204b..156b314aa 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -15,7 +15,6 @@ import types import mock -import pytest from datetime 
import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -518,7 +517,7 @@ def test_collectionreference_pipeline(): client = _test_helpers.make_client() collection = _make_collection_reference("collection", client=client) - pipeline = collection.pipeline() + pipeline = collection._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) # should have single "Collection" stage assert len(pipeline.stages) == 1 @@ -526,12 +525,6 @@ def test_collectionreference_pipeline(): assert pipeline.stages[0].path == "/collection" -def test_collectionreference_pipeline_no_client(): - collection = _make_collection_reference("collection") - with pytest.raises(ValueError, match="client"): - collection.pipeline() - - @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) def test_on_snapshot(watch): collection = _make_collection_reference("collection") diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index e29b763e2..69754a941 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +import mock from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.pipeline import Pipeline @@ -19,6 +20,8 @@ from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.query import Query +from google.cloud.firestore_v1.async_query import AsyncQuery class TestPipelineSource: @@ -27,6 +30,9 @@ class TestPipelineSource: def _make_client(self): return Client() + def _make_query(self): + return Query(mock.Mock()) + def test_make_from_client(self): instance = self._make_client().pipeline() assert isinstance(instance, PipelineSource) @@ -36,6 +42,23 @@ def test_create_pipeline(self): ppl = instance._create_pipeline(None) assert isinstance(ppl, self._expected_pipeline_type) + def test_create_from_mock(self): + mock_query = mock.Mock() + expected = object() + mock_query._build_pipeline.return_value = expected + instance = self._make_client().pipeline() + got = instance.create_from(mock_query) + assert got == expected + assert mock_query._build_pipeline.call_count == 1 + assert mock_query._build_pipeline.call_args_list[0][0][0] == instance + + def test_create_from_query(self): + query = self._make_query() + instance = self._make_client().pipeline() + ppl = instance.create_from(query) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + def test_collection(self): instance = self._make_client().pipeline() ppl = instance.collection("path") @@ -98,3 +121,6 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource): def _make_client(self): return AsyncClient() + + def _make_query(self): + return AsyncQuery(mock.Mock()) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index 8b1217370..7eaeef61b 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -1054,7 +1054,7 @@ def test_asyncquery_collection_pipeline_type(): client = make_client() parent = client.collection("test") query = parent._query() - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, Pipeline) @@ -1063,5 
+1063,5 @@ def test_asyncquery_collectiongroup_pipeline_type(): client = make_client() query = client.collection_group("test") - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, Pipeline) From aef4391b47a08414cff3399c17ffb1fa1a34647f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 7 Nov 2025 15:49:04 -0800 Subject: [PATCH 12/27] feat: pipeline read time (#1125) --- google/cloud/firestore_v1/async_pipeline.py | 20 ++++++++- google/cloud/firestore_v1/base_pipeline.py | 6 ++- google/cloud/firestore_v1/pipeline.py | 18 ++++++++- tests/system/test_system.py | 44 ++++++++++++++++++++ tests/system/test_system_async.py | 45 +++++++++++++++++++++ tests/unit/v1/test_async_pipeline.py | 44 ++++++++++++++++++-- tests/unit/v1/test_pipeline.py | 42 +++++++++++++++++-- 7 files changed, 206 insertions(+), 13 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 9fe0c8756..f175b5394 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -18,6 +18,7 @@ from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER + import datetime from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.async_transaction import AsyncTransaction @@ -59,6 +60,7 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage): async def execute( self, transaction: "AsyncTransaction" | None = None, + read_time: datetime.datetime | None = None, ) -> list[PipelineResult]: """ Executes this pipeline and returns results as a list @@ -70,12 +72,22 @@ async def execute( If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. """ - return [result async for result in self.stream(transaction=transaction)] + return [ + result + async for result in self.stream( + transaction=transaction, read_time=read_time + ) + ] async def stream( self, transaction: "AsyncTransaction" | None = None, + read_time: datetime.datetime | None = None, ) -> AsyncIterable[PipelineResult]: """ Process this pipeline as a stream, providing results through an Iterable @@ -87,8 +99,12 @@ async def stream( If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. 
""" - request = self._prep_execute_request(transaction) + request = self._prep_execute_request(transaction, read_time) async for response in await self._client._firestore_api.execute_pipeline( request ): diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index c66321793..7f52c2021 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -33,6 +33,7 @@ from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: # pragma: NO COVER + import datetime from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse @@ -99,7 +100,9 @@ def _append(self, new_stage): return self.__class__._create_with_stages(self._client, *self.stages, new_stage) def _prep_execute_request( - self, transaction: BaseTransaction | None + self, + transaction: BaseTransaction | None, + read_time: datetime.datetime | None, ) -> ExecutePipelineRequest: """ shared logic for creating an ExecutePipelineRequest @@ -116,6 +119,7 @@ def _prep_execute_request( database=database_name, transaction=transaction_id, structured_pipeline=self._to_pb(), + read_time=read_time, ) return request diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index f578e00b6..b4567189b 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -18,6 +18,7 @@ from google.cloud.firestore_v1.base_pipeline import _BasePipeline if TYPE_CHECKING: # pragma: NO COVER + import datetime from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.transaction import Transaction @@ -56,6 +57,7 @@ def __init__(self, client: Client, *stages: stages.Stage): def execute( self, transaction: "Transaction" | None = None, + read_time: datetime.datetime | None = None, ) -> list[PipelineResult]: """ Executes this pipeline and returns results as a list @@ -67,12 +69,20 @@ def execute( If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. """ - return [result for result in self.stream(transaction=transaction)] + return [ + result + for result in self.stream(transaction=transaction, read_time=read_time) + ] def stream( self, transaction: "Transaction" | None = None, + read_time: datetime.datetime | None = None, ) -> Iterable[PipelineResult]: """ Process this pipeline as a stream, providing results through an Iterable @@ -84,7 +94,11 @@ def stream( If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not allowed). + read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given + time. This must be a microsecond precision timestamp within the past one hour, or + if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp + within the past 7 days. For the most accurate results, use UTC timezone. 
""" - request = self._prep_execute_request(transaction) + request = self._prep_execute_request(transaction, read_time) for response in self._client._firestore_api.execute_pipeline(request): yield from self._execute_response_helper(response) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 09dc607eb..61b1a983c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -46,6 +46,7 @@ TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, IS_KOKORO_TEST, + FIRESTORE_ENTERPRISE_DB, ) @@ -1689,6 +1690,49 @@ def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data +@pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find a read_time before adding the new document. + snapshots = collection.get() + read_time = snapshots[0].read_time + + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + # new query should have new_data + new_results = list(pipeline.stream()) + new_values = {result.ref.id: result.data() for result in new_results} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + # query with read_time should not have new)data + results = list(pipeline.stream(read_time=read_time)) + + values = {result.ref.id: result.data() for result in results} + + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref.id + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_order_dot_key(client, cleanup, database): db = client diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 5f8e07eda..99b9da801 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -57,6 +57,7 @@ TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, IS_KOKORO_TEST, + FIRESTORE_ENTERPRISE_DB, ) RETRIES = retries.AsyncRetry( @@ -1612,6 +1613,50 @@ async def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data +@pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find a read_time before adding the new document. + snapshots = await collection.get() + read_time = snapshots[0].read_time + + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. 
+ cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + + # new query should have new_data + new_results = [result async for result in pipeline.stream()] + new_values = {result.ref.id: result.data() for result in new_results} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + # pipeline with read_time should not have new_data + results = [result async for result in pipeline.stream(read_time=read_time)] + + values = {result.ref.id: result.data() for result in results} + + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref.id + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 2fc39a906..189b24fba 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -320,6 +320,36 @@ async def test_async_pipeline_stream_with_transaction(): assert request.transaction == b"123" +@pytest.mark.asyncio +async def test_async_pipeline_stream_with_read_time(): + """ + test stream pipeline with read_time + """ + import datetime + + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(client=client) + + [r async for r in ppl_1.stream(read_time=read_time)] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.read_time == read_time + + @pytest.mark.asyncio async def test_async_pipeline_stream_stream_equivalence(): """ @@ -364,16 +394,22 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): """ pipeline.stream should call pipeline.stream internally """ + import datetime + ppl_1 = _make_async_pipeline() expected_data = [object(), object()] - expected_arg = object() + expected_transaction = object() + expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) with mock.patch.object(ppl_1, "stream") as mock_stream: mock_stream.return_value = _async_it(expected_data) - stream_results = await ppl_1.execute(expected_arg) + stream_results = await ppl_1.execute( + transaction=expected_transaction, read_time=expected_read_time + ) assert mock_stream.call_count == 1 assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 1 - assert mock_stream.call_args[1]["transaction"] == expected_arg + assert len(mock_stream.call_args[1]) == 2 + assert mock_stream.call_args[1]["transaction"] == expected_transaction + assert mock_stream.call_args[1]["read_time"] == expected_read_time assert stream_results == expected_data diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index e203f6d69..34d3400e8 100644 --- a/tests/unit/v1/test_pipeline.py +++ 
b/tests/unit/v1/test_pipeline.py @@ -302,6 +302,34 @@ def test_pipeline_stream_with_transaction(): assert request.transaction == b"123" +def test_pipeline_stream_with_read_time(): + """ + test stream pipeline with read_time + """ + import datetime + + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + + read_time = datetime.datetime.now(tz=datetime.timezone.utc) + + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(client=client) + + list(ppl_1.stream(read_time=read_time)) + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.read_time == read_time + + def test_pipeline_execute_stream_equivalence(): """ Pipeline.execute should provide same results from pipeline.stream, as a list @@ -341,16 +369,22 @@ def test_pipeline_execute_stream_equivalence_mocked(): """ pipeline.execute should call pipeline.stream internally """ + import datetime + ppl_1 = _make_pipeline() expected_data = [object(), object()] - expected_arg = object() + expected_transaction = object() + expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) with mock.patch.object(ppl_1, "stream") as mock_stream: mock_stream.return_value = expected_data - stream_results = ppl_1.execute(expected_arg) + stream_results = ppl_1.execute( + transaction=expected_transaction, read_time=expected_read_time + ) assert mock_stream.call_count == 1 assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 1 - assert mock_stream.call_args[1]["transaction"] == expected_arg + assert len(mock_stream.call_args[1]) == 2 + assert mock_stream.call_args[1]["transaction"] == expected_transaction + assert mock_stream.call_args[1]["read_time"] == expected_read_time assert stream_results == expected_data From 2d3ed7324273883abf6138fe5408a87d372c277d Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 10 Nov 2025 16:40:43 -0800 Subject: [PATCH 13/27] feat: improve pipeline expressions (#1126) --- .../firestore_v1/pipeline_expressions.py | 39 +++++++--- tests/system/pipeline_e2e/array.yaml | 76 +++++++++++++++++++ tests/system/pipeline_e2e/logical.yaml | 58 ++++++++++++++ tests/unit/v1/test_pipeline_expressions.py | 36 ++++++--- 4 files changed, 187 insertions(+), 22 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7a90dea1d..439b224cc 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -396,7 +396,7 @@ def sqrt(self) -> "Expression": return Function("sqrt", [self]) @expose_as_static - def logical_maximum(self, other: Expression | CONSTANT_TYPE) -> "Expression": + def logical_maximum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -406,23 +406,23 @@ def logical_maximum(self, other: Expression | CONSTANT_TYPE) -> "Expression": Example: >>> # Returns the larger value between the 'discount' field and the 'cap' field. 
>>> Field.of("discount").logical_maximum(Field.of("cap")) - >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_maximum(10) + >>> # Returns the larger value between the 'value' field and some ints + >>> Field.of("value").logical_maximum(10, 20, 30) Args: - other: The other expression or constant value to compare with. + others: The other expression or constant values to compare with. Returns: A new `Expression` representing the logical maximum operation. """ return Function( "maximum", - [self, self._cast_to_expr_or_convert_to_constant(other)], + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], infix_name_override="logical_maximum", ) @expose_as_static - def logical_minimum(self, other: Expression | CONSTANT_TYPE) -> "Expression": + def logical_minimum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -432,18 +432,18 @@ def logical_minimum(self, other: Expression | CONSTANT_TYPE) -> "Expression": Example: >>> # Returns the smaller value between the 'discount' field and the 'floor' field. >>> Field.of("discount").logical_minimum(Field.of("floor")) - >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_minimum(10) + >>> # Returns the smaller value between the 'value' field and some ints + >>> Field.of("value").logical_minimum(10, 20, 30) Args: - other: The other expression or constant value to compare with. + others: The other expression or constant values to compare with. Returns: A new `Expression` representing the logical minimum operation. """ return Function( "minimum", - [self, self._cast_to_expr_or_convert_to_constant(other)], + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], infix_name_override="logical_minimum", ) @@ -629,6 +629,25 @@ def not_equal_any( ], ) + @expose_as_static + def array_get(self, offset: Expression | int) -> "Function": + """ + Creates an expression that indexes into an array from the beginning or end and returns the + element. A negative offset starts from the end. + + Example: + >>> Array([1,2,3]).array_get(0) + + Args: + offset: the index of the element to return + + Returns: + A new `Expression` representing the `array_get` operation. 
+ """ + return Function( + "array_get", [self, self._cast_to_expr_or_convert_to_constant(offset)] + ) + @expose_as_static def array_contains( self, element: Expression | CONSTANT_TYPE diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml index d32491d8b..acdded36b 100644 --- a/tests/system/pipeline_e2e/array.yaml +++ b/tests/system/pipeline_e2e/array.yaml @@ -386,3 +386,79 @@ tests: name: array name: array_concat name: select + - description: testArrayGet + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - Function.array_get: + - Field: tags + - Constant: 0 + - "firstTag" + assert_results: + - firstTag: "comedy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + firstTag: + functionValue: + args: + - fieldReferenceValue: tags + - integerValue: '0' + name: array_get + name: select + - description: testArrayGet_NegativeOffset + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpression: + - Function.array_get: + - Field: tags + - Constant: -1 + - "lastTag" + assert_results: + - lastTag: "adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + lastTag: + functionValue: + args: + - fieldReferenceValue: tags + - integerValue: '-1' + name: array_get + name: select \ No newline at end of file diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml index 0bc889ac1..bbb71921b 100644 --- a/tests/system/pipeline_e2e/logical.yaml +++ b/tests/system/pipeline_e2e/logical.yaml @@ -464,6 +464,64 @@ tests: - doubleValue: 4.5 name: maximum name: select + - description: testLogicalMinMaxWithMultipleInputs + pipeline: + - Collection: books + - Where: + - Function.equal: + - Field: author + - Constant: Douglas Adams + - Select: + - AliasedExpression: + - Function.logical_maximum: + - Field: rating + - Constant: 4.5 + - Constant: 3.0 + - Constant: 5.0 + - "max_rating" + - AliasedExpression: + - Function.logical_minimum: + - Field: published + - Constant: 1900 + - Constant: 2000 + - Constant: 1984 + - "min_published" + assert_results: + - max_rating: 5.0 + min_published: 1900 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + min_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + - integerValue: '2000' + - integerValue: '1984' + name: minimum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + - doubleValue: 3.0 + - doubleValue: 5.0 + name: maximum + name: select - description: testGreaterThanOrEqual pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 1546dbe66..84eb6cfe9 100644 --- 
a/tests/unit/v1/test_pipeline_expressions.py
+++ b/tests/unit/v1/test_pipeline_expressions.py
@@ -717,6 +717,16 @@ def test_or(self):
        assert instance.params == [arg1, arg2]
        assert repr(instance) == "Or(Arg1, Arg2)"

+    def test_array_get(self):
+        arg1 = self._make_arg("ArrayField")
+        arg2 = self._make_arg("Offset")
+        instance = Expression.array_get(arg1, arg2)
+        assert instance.name == "array_get"
+        assert instance.params == [arg1, arg2]
+        assert repr(instance) == "ArrayField.array_get(Offset)"
+        infix_instance = arg1.array_get(arg2)
+        assert infix_instance == instance
+
    def test_array_contains(self):
        arg1 = self._make_arg("ArrayField")
        arg2 = self._make_arg("Element")
@@ -989,23 +999,25 @@ def test_divide(self):
        assert infix_instance == instance

    def test_logical_maximum(self):
-        arg1 = self._make_arg("Left")
-        arg2 = self._make_arg("Right")
-        instance = Expression.logical_maximum(arg1, arg2)
+        arg1 = self._make_arg("A1")
+        arg2 = self._make_arg("A2")
+        arg3 = self._make_arg("A3")
+        instance = Expression.logical_maximum(arg1, arg2, arg3)
        assert instance.name == "maximum"
-        assert instance.params == [arg1, arg2]
-        assert repr(instance) == "Left.logical_maximum(Right)"
-        infix_instance = arg1.logical_maximum(arg2)
+        assert instance.params == [arg1, arg2, arg3]
+        assert repr(instance) == "A1.logical_maximum(A2, A3)"
+        infix_instance = arg1.logical_maximum(arg2, arg3)
        assert infix_instance == instance

    def test_logical_minimum(self):
-        arg1 = self._make_arg("Left")
-        arg2 = self._make_arg("Right")
-        instance = Expression.logical_minimum(arg1, arg2)
+        arg1 = self._make_arg("A1")
+        arg2 = self._make_arg("A2")
+        arg3 = self._make_arg("A3")
+        instance = Expression.logical_minimum(arg1, arg2, arg3)
        assert instance.name == "minimum"
-        assert instance.params == [arg1, arg2]
-        assert repr(instance) == "Left.logical_minimum(Right)"
-        infix_instance = arg1.logical_minimum(arg2)
+        assert instance.params == [arg1, arg2, arg3]
+        assert repr(instance) == "A1.logical_minimum(A2, A3)"
+        infix_instance = arg1.logical_minimum(arg2, arg3)
        assert infix_instance == instance

    def test_to_lower(self):

From 4848fbe29fbed70f35067de694314dc91c2b68f1 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 11 Nov 2025 09:18:23 -0800
Subject: [PATCH 14/27] feat: pipeline explain stats and index mode (#1128)

---
 google/cloud/firestore_v1/async_pipeline.py  |  67 ++--
 google/cloud/firestore_v1/base_pipeline.py   |  54 +--
 google/cloud/firestore_v1/pipeline.py        |  54 ++-
 google/cloud/firestore_v1/pipeline_result.py | 163 ++++++++-
 google/cloud/firestore_v1/query_profile.py   |  83 +++++
 tests/system/test_system.py                  | 142 +++++++-
 tests/system/test_system_async.py            | 196 ++++++++---
 tests/unit/v1/test_async_pipeline.py         |  24 --
 tests/unit/v1/test_pipeline.py               |  34 +-
 tests/unit/v1/test_pipeline_result.py        | 334 +++++++++++++++++++
 tests/unit/v1/test_query_profile.py          |  61 ++++
 11 files changed, 1034 insertions(+), 178 deletions(-)

diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py
index f175b5394..6b017d88e 100644
--- a/google/cloud/firestore_v1/async_pipeline.py
+++ b/google/cloud/firestore_v1/async_pipeline.py
@@ -13,15 +13,20 @@
 # limitations under the License.
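# A rough usage sketch of the reworked async surface defined below: execute()
# and stream() become keyword-only and accept profiling options. Assumes
# `client` is an existing firestore_v1.AsyncClient and "books" is an existing
# collection; the names are illustrative.
from google.cloud.firestore_v1.query_profile import PipelineExplainOptions

async def demo(client):
    ppl = client.pipeline().collection("books")
    stream = ppl.stream(explain_options=PipelineExplainOptions())
    async for result in stream:  # AsyncPipelineStream may be iterated only once
        print(result.data())
    print(stream.explain_stats.get_text())  # populated once the stream is consumed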
 from __future__ import annotations

-from typing import AsyncIterable, TYPE_CHECKING
+from typing import TYPE_CHECKING

 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream
+from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot
+from google.cloud.firestore_v1.pipeline_result import PipelineResult

 if TYPE_CHECKING:  # pragma: NO COVER
     import datetime
     from google.cloud.firestore_v1.async_client import AsyncClient
-    from google.cloud.firestore_v1.pipeline_result import PipelineResult
     from google.cloud.firestore_v1.async_transaction import AsyncTransaction
+    from google.cloud.firestore_v1.pipeline_expressions import Constant
+    from google.cloud.firestore_v1.types.document import Value
+    from google.cloud.firestore_v1.query_profile import PipelineExplainOptions


 class AsyncPipeline(_BasePipeline):
@@ -41,7 +46,7 @@ class AsyncPipeline(_BasePipeline):
     ...     .collection("books")
     ...     .where(Field.of("published").gt(1980))
     ...     .select("title", "author")
-    ...     async for result in pipeline.execute():
+    ...     async for result in pipeline.stream():
     ...         print(result)

     Use `client.pipeline()` to create instances of this class.
@@ -59,15 +64,18 @@ def __init__(self, client: AsyncClient, *stages: stages.Stage):

     async def execute(
         self,
+        *,
         transaction: "AsyncTransaction" | None = None,
         read_time: datetime.datetime | None = None,
-    ) -> list[PipelineResult]:
+        explain_options: PipelineExplainOptions | None = None,
+        index_mode: str | None = None,
+        additional_options: dict[str, Value | Constant] = {},
+    ) -> PipelineSnapshot[PipelineResult]:
         """
         Executes this pipeline and returns results as a list

         Args:
-            transaction
-                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+            transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                 An existing transaction that this query will run in.
                 If a ``transaction`` is used and it already has write operations
                 added, this method cannot be used (i.e. read-after-write is not
@@ -76,25 +84,33 @@ async def execute(
                 time. This must be a microsecond precision timestamp within the past one hour, or
                 if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
                 within the past 7 days. For the most accurate results, use UTC timezone.
+            explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
+                Options to enable query profiling for this query. When set,
+                explain_stats will be available on the returned list.
+            index_mode (Optional[str]): Configures the pipeline to require a certain type of index to be present.
+                Firestore will reject the request if there are no appropriate indexes to serve the query.
+            additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
+                These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode)
         """
-        return [
-            result
-            async for result in self.stream(
-                transaction=transaction, read_time=read_time
-            )
-        ]
+        kwargs = {k: v for k, v in locals().items() if k != "self"}
+        stream = AsyncPipelineStream(PipelineResult, self, **kwargs)
+        results = [result async for result in stream]
+        return PipelineSnapshot(results, stream)

-    async def stream(
+    def stream(
         self,
-        transaction: "AsyncTransaction" | None = None,
+        *,
         read_time: datetime.datetime | None = None,
-    ) -> AsyncIterable[PipelineResult]:
+        transaction: "AsyncTransaction" | None = None,
+        explain_options: PipelineExplainOptions | None = None,
+        index_mode: str | None = None,
+        additional_options: dict[str, Value | Constant] = {},
+    ) -> AsyncPipelineStream[PipelineResult]:
         """
-        Process this pipeline as a stream, providing results through an Iterable
+        Process this pipeline as a stream, providing results through an AsyncIterable

         Args:
-            transaction
-                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+            transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                 An existing transaction that this query will run in.
                 If a ``transaction`` is used and it already has write operations
                 added, this method cannot be used (i.e. read-after-write is not
@@ -103,10 +119,13 @@ def stream(
                 time. This must be a microsecond precision timestamp within the past one hour, or
                 if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
                 within the past 7 days. For the most accurate results, use UTC timezone.
+            explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
+                Options to enable query profiling for this query. When set,
+                explain_stats will be available on the returned generator.
+            index_mode (Optional[str]): Configures the pipeline to require a certain type of index to be present.
+                Firestore will reject the request if there are no appropriate indexes to serve the query.
+            additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
+                These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode)
         """
-        request = self._prep_execute_request(transaction, read_time)
-        async for response in await self._client._firestore_api.execute_pipeline(
-            request
-        ):
-            for result in self._execute_response_helper(response):
-                yield result
+        kwargs = {k: v for k, v in locals().items() if k != "self"}
+        return AsyncPipelineStream(PipelineResult, self, **kwargs)
diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index c66321793..153564663 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -13,15 +13,13 @@
 # limitations under the License.
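# A condensed sketch of the _to_pb(**options) change introduced below:
# per-call options are folded into the StructuredPipeline proto instead of
# being baked into request logic. Assumes `ppl` is an existing pipeline; the
# "recommended" mode string is a hypothetical value used only for illustration.
from google.cloud.firestore_v1.types.document import Value

pb = ppl._to_pb(index_mode=Value(string_value="recommended"))
assert list(pb.pipeline.stages)  # the ordered stage protos
assert pb.options["index_mode"].string_value == "recommended"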
from __future__ import annotations -from typing import Iterable, Sequence, TYPE_CHECKING +from typing import Sequence, TYPE_CHECKING from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.types.pipeline import ( StructuredPipeline as StructuredPipeline_pb, ) from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest -from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( AggregateFunction, AliasedExpression, @@ -30,14 +28,10 @@ BooleanExpression, Selectable, ) -from google.cloud.firestore_v1 import _helpers if TYPE_CHECKING: # pragma: NO COVER - import datetime from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient - from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse - from google.cloud.firestore_v1.transaction import BaseTransaction class _BasePipeline: @@ -88,9 +82,10 @@ def __repr__(self): stages_str = ",\n ".join([repr(s) for s in self.stages]) return f"{cls_str}(\n {stages_str}\n)" - def _to_pb(self) -> StructuredPipeline_pb: + def _to_pb(self, **options) -> StructuredPipeline_pb: return StructuredPipeline_pb( - pipeline={"stages": [s._to_pb() for s in self.stages]} + pipeline={"stages": [s._to_pb() for s in self.stages]}, + options=options, ) def _append(self, new_stage): @@ -99,47 +94,6 @@ def _append(self, new_stage): """ return self.__class__._create_with_stages(self._client, *self.stages, new_stage) - def _prep_execute_request( - self, - transaction: BaseTransaction | None, - read_time: datetime.datetime | None, - ) -> ExecutePipelineRequest: - """ - shared logic for creating an ExecutePipelineRequest - """ - database_name = ( - f"projects/{self._client.project}/databases/{self._client._database}" - ) - transaction_id = ( - _helpers.get_transaction_id(transaction) - if transaction is not None - else None - ) - request = ExecutePipelineRequest( - database=database_name, - transaction=transaction_id, - structured_pipeline=self._to_pb(), - read_time=read_time, - ) - return request - - def _execute_response_helper( - self, response: ExecutePipelineResponse - ) -> Iterable[PipelineResult]: - """ - shared logic for unpacking an ExecutePipelineReponse into PipelineResults - """ - for doc in response.results: - ref = self._client.document(doc.name) if doc.name else None - yield PipelineResult( - self._client, - doc.fields, - ref, - response._pb.execution_time, - doc._pb.create_time if doc.create_time else None, - doc._pb.update_time if doc.update_time else None, - ) - def add_fields(self, *fields: Selectable) -> "_BasePipeline": """ Adds new fields to outputs from previous stages. diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index b4567189b..950eb6ffa 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -13,15 +13,20 @@ # limitations under the License. 
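# A short sketch of the new sync return types defined below: execute()
# returns a PipelineSnapshot (a list subclass carrying metadata) and stream()
# returns a lazy PipelineStream. Assumes `client` is an existing
# firestore_v1.Client and "books" exists; the read_time value is illustrative.
import datetime

read_time = datetime.datetime.now(tz=datetime.timezone.utc)
snapshot = client.pipeline().collection("books").execute(read_time=read_time)
for result in snapshot:  # plain list behavior
    print(result.data())
print(snapshot.execution_time)  # metadata carried on the snapshot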
 from __future__ import annotations

-from typing import Iterable, TYPE_CHECKING
+from typing import TYPE_CHECKING

 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+from google.cloud.firestore_v1.pipeline_result import PipelineStream
+from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot
+from google.cloud.firestore_v1.pipeline_result import PipelineResult

 if TYPE_CHECKING:  # pragma: NO COVER
     import datetime
     from google.cloud.firestore_v1.client import Client
-    from google.cloud.firestore_v1.pipeline_result import PipelineResult
+    from google.cloud.firestore_v1.pipeline_expressions import Constant
     from google.cloud.firestore_v1.transaction import Transaction
+    from google.cloud.firestore_v1.types.document import Value
+    from google.cloud.firestore_v1.query_profile import PipelineExplainOptions


 class Pipeline(_BasePipeline):
@@ -56,15 +61,18 @@ def __init__(self, client: Client, *stages: stages.Stage):

     def execute(
         self,
+        *,
         transaction: "Transaction" | None = None,
         read_time: datetime.datetime | None = None,
-    ) -> list[PipelineResult]:
+        explain_options: PipelineExplainOptions | None = None,
+        index_mode: str | None = None,
+        additional_options: dict[str, Value | Constant] = {},
+    ) -> PipelineSnapshot[PipelineResult]:
         """
         Executes this pipeline and returns results as a list

         Args:
-            transaction
-                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+            transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                 An existing transaction that this query will run in.
                 If a ``transaction`` is used and it already has write operations
                 added, this method cannot be used (i.e. read-after-write is not
@@ -73,12 +81,20 @@ def execute(
                 time. This must be a microsecond precision timestamp within the past one hour, or
                 if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
                 within the past 7 days. For the most accurate results, use UTC timezone.
+            explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
+                Options to enable query profiling for this query. When set,
+                explain_stats will be available on the returned list.
+            index_mode (Optional[str]): Configures the pipeline to require a certain type of index to be present.
+                Firestore will reject the request if there are no appropriate indexes to serve the query.
+            additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
+                These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode)
         """
-        return [
-            result
-            for result in self.stream(transaction=transaction, read_time=read_time)
-        ]
+        kwargs = {k: v for k, v in locals().items() if k != "self"}
+        stream = PipelineStream(PipelineResult, self, **kwargs)
+        results = [result for result in stream]
+        return PipelineSnapshot(results, stream)

     def stream(
         self,
+        *,
         transaction: "Transaction" | None = None,
         read_time: datetime.datetime | None = None,
-    ) -> Iterable[PipelineResult]:
+        explain_options: PipelineExplainOptions | None = None,
+        index_mode: str | None = None,
+        additional_options: dict[str, Value | Constant] = {},
+    ) -> PipelineStream[PipelineResult]:
         """
         Process this pipeline as a stream, providing results through an Iterable

         Args:
-            transaction
-                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
+            transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                 An existing transaction that this query will run in.
                 If a ``transaction`` is used and it already has write operations
                 added, this method cannot be used (i.e. read-after-write is not
@@ -98,7 +116,13 @@ def stream(
                 time. This must be a microsecond precision timestamp within the past one hour, or
                 if Point-in-Time Recovery is enabled, can additionally be a whole minute timestamp
                 within the past 7 days. For the most accurate results, use UTC timezone.
+            explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
+                Options to enable query profiling for this query. When set,
+                explain_stats will be available on the returned generator.
+            index_mode (Optional[str]): Configures the pipeline to require a certain type of index to be present.
+                Firestore will reject the request if there are no appropriate indexes to serve the query.
+            additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
+                These options will take precedence over method arguments if there is a conflict (e.g. explain_options, index_mode)
         """
-        request = self._prep_execute_request(transaction, read_time)
-        for response in self._client._firestore_api.execute_pipeline(request):
-            yield from self._execute_response_helper(response)
+        kwargs = {k: v for k, v in locals().items() if k != "self"}
+        return PipelineStream(PipelineResult, self, **kwargs)
diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py
index ada855fea..6be08fa57 100644
--- a/google/cloud/firestore_v1/pipeline_result.py
+++ b/google/cloud/firestore_v1/pipeline_result.py
@@ -13,17 +13,43 @@
 # limitations under the License.
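# The containers defined below are deliberately single-use: a second
# iteration of a PipelineStream raises RuntimeError, and explain_stats only
# becomes readable once the stream has been fully consumed. A minimal sketch,
# assuming `ppl` is an existing Pipeline:
from google.cloud.firestore_v1.query_profile import PipelineExplainOptions

stream = ppl.stream(explain_options=PipelineExplainOptions())
results = list(stream)        # first and only pass over the responses
stats = stream.explain_stats  # populated while iterating
# iterating `stream` a second time would raise RuntimeError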
from __future__ import annotations -from typing import Any, MutableMapping, TYPE_CHECKING +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Iterable, + Iterator, + Generic, + MutableMapping, + Type, + TypeVar, + TYPE_CHECKING, +) from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import get_nested_value from google.cloud.firestore_v1.field_path import FieldPath +from google.cloud.firestore_v1.query_profile import ExplainStats +from google.cloud.firestore_v1.query_profile import QueryExplainError +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.types.document import Value if TYPE_CHECKING: # pragma: NO COVER + import datetime + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + from google.cloud.firestore_v1.transaction import Transaction from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse from google.cloud.firestore_v1.types.document import Value as ValueProto from google.cloud.firestore_v1.vector import Vector + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.base_pipeline import _BasePipeline + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_expressions import Constant + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions class PipelineResult: @@ -137,3 +163,138 @@ def get(self, field_path: str | FieldPath) -> Any: ) value = get_nested_value(str_path, self._fields_pb) return _helpers.decode_value(value, self._client) + + +T = TypeVar("T", bound=PipelineResult) + + +class _PipelineResultContainer(Generic[T]): + """Base class to hold shared attributes for PipelineSnapshot and PipelineStream""" + + def __init__( + self, + return_type: Type[T], + pipeline: Pipeline | AsyncPipeline, + transaction: Transaction | AsyncTransaction | None, + read_time: datetime.datetime | None, + explain_options: PipelineExplainOptions | None, + index_mode: str | None, + additional_options: dict[str, Constant | Value], + ): + # public + self.transaction = transaction + self.pipeline: _BasePipeline = pipeline + self.execution_time: Timestamp | None = None + # private + self._client: Client | AsyncClient = pipeline._client + self._started: bool = False + self._read_time = read_time + self._explain_stats: ExplainStats | None = None + self._explain_options: PipelineExplainOptions | None = explain_options + self._return_type = return_type + self._index_mode = index_mode + self._additonal_options = { + k: v if isinstance(v, Value) else v._to_pb() + for k, v in additional_options.items() + } + + @property + def explain_stats(self) -> ExplainStats: + if self._explain_stats is not None: + return self._explain_stats + elif self._explain_options is None: + raise QueryExplainError("explain_options not set on query.") + elif not self._started: + raise QueryExplainError( + "explain_stats not available until query is complete" + ) + else: + raise QueryExplainError("explain_stats not found") + + def _build_request(self) -> ExecutePipelineRequest: + """ + shared logic for creating an ExecutePipelineRequest + """ + database_name = ( + 
f"projects/{self._client.project}/databases/{self._client._database}" + ) + transaction_id = ( + _helpers.get_transaction_id(self.transaction, read_operation=False) + if self.transaction is not None + else None + ) + options = {} + if self._explain_options: + options["explain_options"] = self._explain_options._to_value() + if self._index_mode: + options["index_mode"] = Value(string_value=self._index_mode) + if self._additonal_options: + options.update(self._additonal_options) + request = ExecutePipelineRequest( + database=database_name, + transaction=transaction_id, + structured_pipeline=self.pipeline._to_pb(**options), + read_time=self._read_time, + ) + return request + + def _process_response(self, response: ExecutePipelineResponse) -> Iterable[T]: + """Shared logic for processing an individual response from a stream""" + if response.explain_stats: + self._explain_stats = ExplainStats(response.explain_stats) + execution_time = response._pb.execution_time + if execution_time and not self.execution_time: + self.execution_time = execution_time + for doc in response.results: + ref = self._client.document(doc.name) if doc.name else None + yield self._return_type( + self._client, + doc.fields, + ref, + execution_time, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, + ) + + +class PipelineSnapshot(_PipelineResultContainer[T], list[T]): + """ + A list type that holds the result of a pipeline.execute() operation, along with related metadata + """ + + def __init__(self, results_list: list[T], source: _PipelineResultContainer[T]): + self.__dict__.update(source.__dict__.copy()) + list.__init__(self, results_list) + # snapshots are always complete + self._started = True + + +class PipelineStream(_PipelineResultContainer[T], Iterable[T]): + """ + An iterable stream representing the result of a pipeline.stream() operation, along with related metadata + """ + + def __iter__(self) -> Iterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = self._client._firestore_api.execute_pipeline(request) + for response in stream: + yield from self._process_response(response) + + +class AsyncPipelineStream(_PipelineResultContainer[T], AsyncIterable[T]): + """ + An iterable stream representing the result of an async pipeline.stream() operation, along with related metadata + """ + + async def __aiter__(self) -> AsyncIterator[T]: + if self._started: + raise RuntimeError(f"{self.__class__.__name__} can only be iterated once") + self._started = True + request = self._build_request() + stream = await self._client._firestore_api.execute_pipeline(request) + async for response in stream: + for result in self._process_response(response): + yield result diff --git a/google/cloud/firestore_v1/query_profile.py b/google/cloud/firestore_v1/query_profile.py index 6925f83ff..5e8491fc6 100644 --- a/google/cloud/firestore_v1/query_profile.py +++ b/google/cloud/firestore_v1/query_profile.py @@ -19,6 +19,12 @@ from dataclasses import dataclass from google.protobuf.json_format import MessageToDict +from google.cloud.firestore_v1.types.document import MapValue +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, +) +from google.protobuf.wrappers_pb2 import StringValue @dataclass(frozen=True) @@ -42,6 +48,32 @@ def _to_dict(self): return {"analyze": self.analyze} 
+@dataclass(frozen=True)
+class PipelineExplainOptions:
+    """
+    Explain options for pipeline queries.
+
+    Set on a pipeline.execute() or pipeline.stream() call, to provide
+    explain_stats in the pipeline output.
+
+    :type mode: str
+    :param mode: Optional. The mode of operation for this explain query.
+        When set to 'analyze', the query will be executed and return the full
+        query results along with execution statistics.
+    """
+
+    mode: str = "analyze"
+
+    def _to_value(self):
+        out_dict = {"mode": Value(string_value=self.mode)}
+        value_pb = MapValue(fields=out_dict)
+        return Value(map_value=value_pb)
+
+
 @dataclass(frozen=True)
 class PlanSummary:
     """
@@ -143,3 +175,54 @@ class QueryExplainError(Exception):
     """

     pass
+
+
+class ExplainStats:
+    """
+    Contains query profiling statistics for a pipeline query.
+
+    This class is not meant to be instantiated directly by the user. Instead, an
+    instance of `ExplainStats` may be returned by pipeline execution methods
+    when `explain_options` are provided.
+
+    It provides methods to access the explain statistics in different formats.
+    """
+
+    def __init__(self, stats_pb: ExplainStats_pb):
+        """
+        Args:
+            stats_pb (ExplainStats_pb): The raw protobuf message for explain stats.
+        """
+        self._stats_pb = stats_pb
+
+    def get_text(self) -> str:
+        """
+        Returns the explain stats as a string.
+
+        This method is suitable for explain formats that have a text-based output,
+        such as 'text' or 'json'.
+
+        Returns:
+            str: The string representation of the explain stats.
+
+        Raises:
+            QueryExplainError: If the explain stats payload from the backend is not
+                a string. This can happen if a non-text output format was requested.
+        """
+        pb_data = self._stats_pb._pb.data
+        content = StringValue()
+        if pb_data.Unpack(content):
+            return content.value
+        raise QueryExplainError(
+            "Unable to decode explain stats. Did you request an output format that returns a string value, such as 'text' or 'json'?"
+        )
+
+    def get_raw(self) -> ExplainStats_pb:
+        """
+        Returns the explain stats in an encoded proto format, as returned from the Firestore backend.
+        The caller is responsible for unpacking this proto message.
+
+        Returns:
+            google.cloud.firestore_v1.types.explain_stats.ExplainStats: the proto from the backend
+        """
+        return self._stats_pb
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 61b1a983c..615ff1226 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -21,6 +21,7 @@
 import google.auth
 import pytest
+import mock
 from google.api_core.exceptions import (
     AlreadyExists,
     FailedPrecondition,
@@ -1652,6 +1653,140 @@ def test_query_stream_or_get_w_explain_options_analyze_false(
     explain_metrics.execution_stats


+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
+)
+@pytest.mark.parametrize("method", ["execute", "stream"])
+@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
+def test_pipeline_explain_options_explain_mode(database, method, query_docs):
+    """Explain currently not supported by backend. Expect error"""
+    from google.cloud.firestore_v1.query_profile import (
+        PipelineExplainOptions,
+    )
+
+    collection, _, _ = query_docs
+    client = collection._client
+    query = collection.where(filter=FieldFilter("a", "==", 1))
+    pipeline = client.pipeline().create_from(query)
+
+    # Tests either `execute()` or `stream()`.
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 61b1a983c..615ff1226 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -21,6 +21,7 @@
 
 import google.auth
 import pytest
+import mock
 from google.api_core.exceptions import (
     AlreadyExists,
     FailedPrecondition,
@@ -1652,6 +1653,140 @@ def test_query_stream_or_get_w_explain_options_analyze_false(
         explain_metrics.execution_stats
 
 
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
+)
+@pytest.mark.parametrize("method", ["execute", "stream"])
+@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
+def test_pipeline_explain_options_explain_mode(database, method, query_docs):
+    """Explain mode is currently not supported by the backend. Expect an error."""
+    from google.cloud.firestore_v1.query_profile import (
+        PipelineExplainOptions,
+    )
+
+    collection, _, _ = query_docs
+    client = collection._client
+    query = collection.where(filter=FieldFilter("a", "==", 1))
+    pipeline = client.pipeline().create_from(query)
+
+    # Tests either `execute()` or `stream()`.
+    method_under_test = getattr(pipeline, method)
+    explain_options = PipelineExplainOptions(mode="explain")
+
+    # for now, expect error on explain mode
+    with pytest.raises(InvalidArgument) as e:
+        results = method_under_test(explain_options=explain_options)
+        list(results)
+    assert "Explain execution mode is not supported" in str(e)
+
+
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
+)
+@pytest.mark.parametrize("method", ["execute", "stream"])
+@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
+def test_pipeline_explain_options_analyze_mode(database, method, query_docs):
+    from google.cloud.firestore_v1.query_profile import (
+        PipelineExplainOptions,
+        ExplainStats,
+        QueryExplainError,
+    )
+    from google.cloud.firestore_v1.types.explain_stats import (
+        ExplainStats as ExplainStats_pb,
+    )
+
+    collection, _, allowed_vals = query_docs
+    client = collection._client
+    query = collection.where(filter=FieldFilter("a", "==", 1))
+    pipeline = client.pipeline().create_from(query)
+
+    # Tests either `execute()` or `stream()`.
+    method_under_test = getattr(pipeline, method)
+    results = method_under_test(explain_options=PipelineExplainOptions())
+
+    if method == "stream":
+        # check for error accessing explain stats before iterating
+        with pytest.raises(
+            QueryExplainError,
+            match="explain_stats not available until query is complete",
+        ):
+            results.explain_stats
+
+    # Finish iterating results, and explain_stats should be available.
+    results_list = list(results)
+    num_results = len(results_list)
+    assert num_results == len(allowed_vals)
+
+    # Verify explain_stats.
+    explain_stats = results.explain_stats
+    assert isinstance(explain_stats, ExplainStats)
+
+    assert isinstance(explain_stats.get_raw(), ExplainStats_pb)
+    text_stats = explain_stats.get_text()
+    assert "Execution:" in text_stats
+
+
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
+)
+@pytest.mark.parametrize("method", ["execute", "stream"])
+@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
+def test_pipeline_explain_options_using_additional_options(
+    database, method, query_docs
+):
+    """The additional_options field allows passing in arbitrary options. Test it with explain_options."""
+    from google.cloud.firestore_v1.query_profile import (
+        PipelineExplainOptions,
+        ExplainStats,
+    )
+    from google.cloud.firestore_v1.types.explain_stats import (
+        ExplainStats as ExplainStats_pb,
+    )
+
+    collection, _, allowed_vals = query_docs
+    client = collection._client
+    query = collection.where(filter=FieldFilter("a", "==", 1))
+    pipeline = client.pipeline().create_from(query)
+
+    # Tests either `execute()` or `stream()`.
+    method_under_test = getattr(pipeline, method)
+
+    encoded_options = {"explain_options": PipelineExplainOptions()._to_value()}
+
+    results = method_under_test(
+        explain_options=mock.Mock(), additional_options=encoded_options
+    )
+
+    # Finish iterating results, and explain_stats should be available.
+    results_list = list(results)
+    num_results = len(results_list)
+    assert num_results == len(allowed_vals)
+
+    # Verify explain_stats.
+    explain_stats = results.explain_stats
+    assert isinstance(explain_stats, ExplainStats)
+
+    assert isinstance(explain_stats.get_raw(), ExplainStats_pb)
+    text_stats = explain_stats.get_text()
+    assert "Execution:" in text_stats
+
+
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
+) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +def test_pipeline_index_mode(database, query_docs): + """test pipeline query with explicit index mode""" + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + with pytest.raises(InvalidArgument) as e: + pipeline.execute(index_mode="fake_index") + assert "Invalid index_mode: fake_index" in str(e) + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs @@ -1703,15 +1838,16 @@ def test_pipeline_w_read_time(query_docs, cleanup, database): new_data = { "a": 9000, "b": 1, - "c": [10000, 1000], - "stats": {"sum": 9001, "product": 9000}, } _, new_ref = collection.add(new_data) # Add to clean-up. cleanup(new_ref.delete) stored[new_ref.id] = new_data - pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) + # new query should have new_data new_results = list(pipeline.stream()) new_values = {result.ref.id: result.data() for result in new_results} diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 99b9da801..373c40118 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -22,6 +22,7 @@ import google.auth import pytest import pytest_asyncio +import mock from google.api_core import exceptions as core_exceptions from google.api_core import retry_async as retries from google.api_core.exceptions import ( @@ -1573,44 +1574,124 @@ async def test_query_stream_or_get_w_explain_options_analyze_false( _verify_explain_metrics_analyze_false(explain_metrics) -@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) -async def test_query_stream_w_read_time(query_docs, cleanup, database): - collection, stored, allowed_vals = query_docs - num_vals = len(allowed_vals) +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_explain_mode(database, method, query_docs): + """Explain currently not supported by backend. Expect error""" + from google.api_core.exceptions import InvalidArgument + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ) - # Find the most recent read_time in collections - read_time = max( - [(await docref.get()).read_time async for docref in collection.list_documents()] + collection, _, _ = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions(mode="explain") + + with pytest.raises(InvalidArgument) as e: + if method == "stream": + results = method_under_test(explain_options=explain_options) + _ = [i async for i in results] + else: + await method_under_test(explain_options=explain_options) + + assert "Explain execution mode is not supported" in str(e.value) + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
+) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_analyze_mode(database, method, query_docs): + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, ) - new_data = { - "a": 9000, - "b": 1, - "c": [10000, 1000], - "stats": {"sum": 9001, "product": 9000}, - } - _, new_ref = await collection.add(new_data) - # Add to clean-up. - cleanup(new_ref.delete) - stored[new_ref.id] = new_data - # Compare query at read_time to query at current time. - query = collection.where(filter=FieldFilter("b", "==", 1)) - values = { - snapshot.id: snapshot.to_dict() - async for snapshot in query.stream(read_time=read_time) - } - assert len(values) == num_vals - assert new_ref.id not in values - for key, value in values.items(): - assert stored[key] == value - assert value["b"] == 1 - assert value["a"] != 9000 - assert key != new_ref + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) - new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} - assert len(new_values) == num_vals + 1 - assert new_ref.id in new_values - assert new_values[new_ref.id] == new_data + method_under_test = getattr(pipeline, method) + explain_options = PipelineExplainOptions() + + if method == "execute": + results = await method_under_test(explain_options=explain_options) + num_results = len(results) + else: + results = method_under_test(explain_options=explain_options) + with pytest.raises( + QueryExplainError, + match="explain_stats not available until query is complete", + ): + results.explain_stats + + num_results = len([item async for item in results]) + + explain_stats = results.explain_stats + + assert num_results == len(allowed_vals) + + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats + + +@pytest.mark.skipif( + FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." +) +@pytest.mark.parametrize("method", ["execute", "stream"]) +@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) +async def test_pipeline_explain_options_using_additional_options( + database, method, query_docs +): + """additional_options field allows passing in arbitrary options. 
Test with explain_options""" + from google.cloud.firestore_v1.query_profile import ( + PipelineExplainOptions, + ExplainStats, + ) + from google.cloud.firestore_v1.types.explain_stats import ( + ExplainStats as ExplainStats_pb, + ) + + collection, _, allowed_vals = query_docs + client = collection._client + query = collection.where(filter=FieldFilter("a", "==", 1)) + pipeline = client.pipeline().create_from(query) + + method_under_test = getattr(pipeline, method) + encoded_options = {"explain_options": PipelineExplainOptions()._to_value()} + + stub = method_under_test( + explain_options=mock.Mock(), additional_options=encoded_options + ) + if method == "execute": + results = await stub + num_results = len(results) + else: + results = stub + num_results = len([item async for item in results]) + + assert num_results == len(allowed_vals) + + explain_stats = results.explain_stats + assert isinstance(explain_stats, ExplainStats) + assert isinstance(explain_stats.get_raw(), ExplainStats_pb) + text_stats = explain_stats.get_text() + assert "Execution:" in text_stats @pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") @@ -1626,15 +1707,14 @@ async def test_pipeline_w_read_time(query_docs, cleanup, database): new_data = { "a": 9000, "b": 1, - "c": [10000, 1000], - "stats": {"sum": 9001, "product": 9000}, } _, new_ref = await collection.add(new_data) # Add to clean-up. cleanup(new_ref.delete) stored[new_ref.id] = new_data - - pipeline = collection.where(filter=FieldFilter("b", "==", 1)).pipeline() + client = collection._client + query = collection.where(filter=FieldFilter("b", "==", 1)) + pipeline = client.pipeline().create_from(query) # new query should have new_data new_results = [result async for result in pipeline.stream()] @@ -1657,6 +1737,46 @@ async def test_pipeline_w_read_time(query_docs, cleanup, database): assert key != new_ref.id +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) +async def test_query_stream_w_read_time(query_docs, cleanup, database): + collection, stored, allowed_vals = query_docs + num_vals = len(allowed_vals) + + # Find the most recent read_time in collections + read_time = max( + [(await docref.get()).read_time async for docref in collection.list_documents()] + ) + new_data = { + "a": 9000, + "b": 1, + "c": [10000, 1000], + "stats": {"sum": 9001, "product": 9000}, + } + _, new_ref = await collection.add(new_data) + # Add to clean-up. + cleanup(new_ref.delete) + stored[new_ref.id] = new_data + + # Compare query at read_time to query at current time. 
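+    # read_time pins the query to a snapshot of the database at that
+    # timestamp, so new_ref (added after read_time was captured) should
+    # be absent here and only appear in the current-time query below.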
+ query = collection.where(filter=FieldFilter("b", "==", 1)) + values = { + snapshot.id: snapshot.to_dict() + async for snapshot in query.stream(read_time=read_time) + } + assert len(values) == num_vals + assert new_ref.id not in values + for key, value in values.items(): + assert stored[key] == value + assert value["b"] == 1 + assert value["a"] != 9000 + assert key != new_ref + + new_values = {snapshot.id: snapshot.to_dict() async for snapshot in query.stream()} + assert len(new_values) == num_vals + 1 + assert new_ref.id in new_values + assert new_values[new_ref.id] == new_data + + @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_order_dot_key(client, cleanup, database): db = client diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 189b24fba..5a7fb360c 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -389,30 +389,6 @@ async def test_async_pipeline_stream_stream_equivalence(): assert stream_results[0].data()["key"] == "str_val" -@pytest.mark.asyncio -async def test_async_pipeline_stream_stream_equivalence_mocked(): - """ - pipeline.stream should call pipeline.stream internally - """ - import datetime - - ppl_1 = _make_async_pipeline() - expected_data = [object(), object()] - expected_transaction = object() - expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) - with mock.patch.object(ppl_1, "stream") as mock_stream: - mock_stream.return_value = _async_it(expected_data) - stream_results = await ppl_1.execute( - transaction=expected_transaction, read_time=expected_read_time - ) - assert mock_stream.call_count == 1 - assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 2 - assert mock_stream.call_args[1]["transaction"] == expected_transaction - assert mock_stream.call_args[1]["read_time"] == expected_read_time - assert stream_results == expected_data - - @pytest.mark.parametrize( "method,args,result_cls", [ diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 34d3400e8..fc8e90a04 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -96,6 +96,17 @@ def test_pipeline__to_pb(): assert pb.pipeline.stages[1] == stage_2._to_pb() +def test_pipeline__to_pb_with_options(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + from google.cloud.firestore_v1.types.document import Value + + ppl = _make_pipeline() + options = {"option_1": Value(integer_value=1)} + pb = ppl._to_pb(**options) + assert isinstance(pb, StructuredPipeline) + assert pb.options["option_1"].integer_value == 1 + + def test_pipeline_append(): """append should create a new pipeline with the additional stage""" @@ -365,29 +376,6 @@ def test_pipeline_execute_stream_equivalence(): assert execute_results[0].data()["key"] == "str_val" -def test_pipeline_execute_stream_equivalence_mocked(): - """ - pipeline.execute should call pipeline.stream internally - """ - import datetime - - ppl_1 = _make_pipeline() - expected_data = [object(), object()] - expected_transaction = object() - expected_read_time = datetime.datetime.now(tz=datetime.timezone.utc) - with mock.patch.object(ppl_1, "stream") as mock_stream: - mock_stream.return_value = expected_data - stream_results = ppl_1.execute( - transaction=expected_transaction, read_time=expected_read_time - ) - assert mock_stream.call_count == 1 - assert mock_stream.call_args[0] == () - assert len(mock_stream.call_args[1]) == 2 - assert 
mock_stream.call_args[1]["transaction"] == expected_transaction - assert mock_stream.call_args[1]["read_time"] == expected_read_time - assert stream_results == expected_data - - @pytest.mark.parametrize( "method,args,result_cls", [ diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index 2facf7110..579992741 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ b/tests/unit/v1/test_pipeline_result.py @@ -15,7 +15,30 @@ import mock import pytest +from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse +from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1.pipeline_result import PipelineSnapshot +from google.cloud.firestore_v1.pipeline_result import PipelineStream +from google.cloud.firestore_v1.pipeline_result import AsyncPipelineStream +from google.cloud.firestore_v1.query_profile import QueryExplainError +from google.cloud.firestore_v1.query_profile import PipelineExplainOptions +from google.cloud.firestore_v1._helpers import encode_value +from google.cloud.firestore_v1.types.document import Document +from google.protobuf.timestamp_pb2 import Timestamp + + +_mock_stream_responses = [ + ExecutePipelineResponse( + results=[Document(name="projects/p/databases/d/documents/c/d1", fields={})], + execution_time=Timestamp(seconds=1, nanos=2), + explain_stats={"data": {}}, + ), + ExecutePipelineResponse( + results=[Document(name="projects/p/databases/d/documents/c/d2", fields={})], + execution_time=Timestamp(seconds=3, nanos=4), + ), +] class TestPipelineResult: @@ -174,3 +197,314 @@ def test_get_call(self): got = instance.get("key") decode_mock.assert_called_once_with("value", client) assert got == decode_mock.return_value + + +class TestPipelineSnapshot: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [[], mock.Mock()] + return PipelineSnapshot(*args, **kwargs) + + def test_ctor(self): + in_arr = [1, 2, 3] + expected_type = object() + expected_pipeline = mock.Mock() + expected_transaction = object() + expected_read_time = 123 + expected_explain_options = object() + expected_index_mode = "mode" + expected_addtl_options = {} + source = PipelineStream( + expected_type, + expected_pipeline, + expected_transaction, + expected_read_time, + expected_explain_options, + expected_index_mode, + expected_addtl_options, + ) + instance = self._make_one(in_arr, source) + assert instance._return_type == expected_type + assert instance.pipeline == expected_pipeline + assert instance._client == expected_pipeline._client + assert instance._additonal_options == expected_addtl_options + assert instance._index_mode == expected_index_mode + assert instance._explain_options == expected_explain_options + assert instance._explain_stats is None + assert instance._started is True + assert instance.execution_time is None + assert instance.transaction == expected_transaction + assert instance._read_time == expected_read_time + + def test_list_methods(self): + instance = self._make_one(list(range(10)), mock.Mock()) + assert isinstance(instance, list) + assert len(instance) == 10 + assert instance[0] == 0 + assert instance[-1] == 9 + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + 
instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + + +class SharedStreamTests: + """ + Shared test logic for PipelineStream and AsyncPipelineStream + """ + + def _make_one(self, *args, **kwargs): + raise NotImplementedError + + def _mock_init_args(self): + # return default mocks for all init args + from google.cloud.firestore_v1.pipeline import Pipeline + + return { + "return_type": PipelineResult, + "pipeline": Pipeline(mock.Mock()), + "transaction": None, + "read_time": None, + "explain_options": None, + "index_mode": None, + "additional_options": {}, + } + + def test_explain_stats(self): + instance = self._make_one() + expected_stats = mock.Mock() + instance._started = True + instance._explain_stats = expected_stats + assert instance.explain_stats == expected_stats + # test different failure modes + instance._explain_stats = None + instance._explain_options = None + # fail if explain_stats set without explain_options + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_options not set" in str(e) + # fail if explain_stats missing + instance._explain_options = object() + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "explain_stats not found" in str(e) + # fail if not started + instance._started = False + with pytest.raises(QueryExplainError) as e: + instance.explain_stats + assert "not available until query is complete" in str(e) + + @pytest.mark.parametrize( + "init_kwargs,expected_options", + [ + ({"index_mode": "mode"}, {"index_mode": encode_value("mode")}), + ( + {"explain_options": PipelineExplainOptions()}, + {"explain_options": encode_value({"mode": "analyze"})}, + ), + ( + {"explain_options": PipelineExplainOptions(mode="explain")}, + {"explain_options": encode_value({"mode": "explain"})}, + ), + ( + {"additional_options": {"explain_options": Constant("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + {"additional_options": {"explain_options": encode_value("custom")}}, + {"explain_options": encode_value("custom")}, + ), + ( + { + "explain_options": PipelineExplainOptions(), + "additional_options": {"explain_options": Constant.of("override")}, + }, + {"explain_options": encode_value("override")}, + ), + ( + { + "index_mode": "mode", + "additional_options": {"index_mode": Constant("new")}, + }, + {"index_mode": encode_value("new")}, + ), + ], + ) + def test_build_request_options(self, init_kwargs, expected_options): + """ + Certain Arguments to PipelineStream should be passed to `options` field in proto request + """ + instance = self._make_one(**init_kwargs) + request = instance._build_request() + options = dict(request.structured_pipeline.options) + assert options == expected_options + assert len(options) == len(expected_options) + + def test_build_request_transaction(self): + """Ensure transaction is passed down when building request""" + from google.cloud.firestore_v1.transaction import Transaction + + expected_id = b"expected" + transaction = Transaction(mock.Mock()) + transaction._id = expected_id + instance = self._make_one(transaction=transaction) + request = instance._build_request() + assert request.transaction == expected_id + + def 
test_build_request_read_time(self):
+        """Ensure read_time is passed down when building request"""
+        import datetime
+
+        ts = datetime.datetime.now()
+        instance = self._make_one(read_time=ts)
+        request = instance._build_request()
+        assert request.read_time.timestamp() == ts.timestamp()
+
+
+class TestPipelineStream(SharedStreamTests):
+    def _make_one(self, **kwargs):
+        init_kwargs = self._mock_init_args()
+        init_kwargs.update(kwargs)
+        return PipelineStream(**init_kwargs)
+
+    def test_iter(self):
+        pipeline = mock.Mock()
+        pipeline._client.project = "project-id"
+        pipeline._client._database = "database-id"
+        pipeline._client.document.side_effect = lambda path: mock.Mock(
+            id=path.split("/")[-1]
+        )
+        pipeline._to_pb.return_value = {}
+
+        instance = self._make_one(pipeline=pipeline)
+
+        instance._client._firestore_api.execute_pipeline.return_value = (
+            _mock_stream_responses
+        )
+
+        results = list(instance)
+
+        assert len(results) == 2
+        assert isinstance(results[0], PipelineResult)
+        assert results[0].id == "d1"
+        assert isinstance(results[1], PipelineResult)
+        assert results[1].id == "d2"
+
+        assert instance.execution_time.seconds == 1
+        assert instance.execution_time.nanos == 2
+
+        # expect empty stats
+        got_stats = instance.explain_stats.get_raw().data
+        assert got_stats.value == b""
+
+        instance._client._firestore_api.execute_pipeline.assert_called_once()
+
+    def test_double_iterate(self):
+        instance = self._make_one()
+        instance._client._firestore_api.execute_pipeline.return_value = []
+        # consume the iterator
+        list(instance)
+        with pytest.raises(RuntimeError):
+            list(instance)
+
+
+class TestAsyncPipelineStream(SharedStreamTests):
+    def _make_one(self, **kwargs):
+        init_kwargs = self._mock_init_args()
+        init_kwargs.update(kwargs)
+        return AsyncPipelineStream(**init_kwargs)
+
+    @pytest.mark.asyncio
+    async def test_aiter(self):
+        pipeline = mock.Mock()
+        pipeline._client.project = "project-id"
+        pipeline._client._database = "database-id"
+        pipeline._client.document.side_effect = lambda path: mock.Mock(
+            id=path.split("/")[-1]
+        )
+        pipeline._to_pb.return_value = {}
+
+        instance = self._make_one(pipeline=pipeline)
+
+        async def async_gen(items):
+            for item in items:
+                yield item
+
+        instance._client._firestore_api.execute_pipeline = mock.AsyncMock(
+            return_value=async_gen(_mock_stream_responses)
+        )
+
+        results = [item async for item in instance]
+
+        assert len(results) == 2
+        assert isinstance(results[0], PipelineResult)
+        assert results[0].id == "d1"
+        assert isinstance(results[1], PipelineResult)
+        assert results[1].id == "d2"
+
+        assert instance.execution_time.seconds == 1
+        assert instance.execution_time.nanos == 2
+
+        # expect empty 
stats + got_stats = instance.explain_stats.get_raw().data + assert got_stats.value == b"" + + instance._client._firestore_api.execute_pipeline.assert_called_once() + + @pytest.mark.asyncio + async def test_double_iterate(self): + instance = self._make_one() + + async def async_gen(items): + for item in items: + yield item + + # mock the api call to avoid real network requests + instance._client._firestore_api.execute_pipeline = mock.AsyncMock( + return_value=async_gen([]) + ) + + # consume the iterator + [item async for item in instance] + # should fail on second attempt + with pytest.raises(RuntimeError): + [item async for item in instance] diff --git a/tests/unit/v1/test_query_profile.py b/tests/unit/v1/test_query_profile.py index a3b0390c6..5b1e470b8 100644 --- a/tests/unit/v1/test_query_profile.py +++ b/tests/unit/v1/test_query_profile.py @@ -124,3 +124,64 @@ def test_explain_options__to_dict(): assert ExplainOptions(analyze=True)._to_dict() == {"analyze": True} assert ExplainOptions(analyze=False)._to_dict() == {"analyze": False} + + +@pytest.mark.parametrize("mode_str", ["analyze", "explain"]) +def test_pipeline_explain_options__to_value(mode_str): + """ + Should be able to create a Value protobuf representation of ExplainOptions + """ + from google.cloud.firestore_v1.query_profile import PipelineExplainOptions + from google.cloud.firestore_v1.types.document import MapValue + from google.cloud.firestore_v1.types.document import Value + + options = PipelineExplainOptions(mode=mode_str) + expected_value = Value( + map_value=MapValue(fields={"mode": Value(string_value=mode_str)}) + ) + assert options._to_value() == expected_value + + +def test_explain_stats_get_raw(): + """ + Test ExplainStats.get_raw(). Should return input directly + """ + from google.cloud.firestore_v1.query_profile import ExplainStats + + input = object() + stats = ExplainStats(input) + assert stats.get_raw() is input + + +def test_explain_stats_get_text(): + """ + Test ExplainStats.get_text() + """ + from google.cloud.firestore_v1.query_profile import ExplainStats + from google.cloud.firestore_v1.types import explain_stats as explain_stats_pb2 + from google.protobuf import any_pb2 + from google.protobuf import wrappers_pb2 + + expected_text = "some text" + text_pb = any_pb2.Any() + text_pb.Pack(wrappers_pb2.StringValue(value=expected_text)) + expected_stats_pb = explain_stats_pb2.ExplainStats(data=text_pb) + stats = ExplainStats(expected_stats_pb) + assert stats.get_text() == expected_text + + +def test_explain_stats_get_text_error(): + """ + Test ExplainStats.get_text() raises QueryExplainError + """ + from google.cloud.firestore_v1.query_profile import ( + ExplainStats, + QueryExplainError, + ) + from google.cloud.firestore_v1.types import explain_stats as explain_stats_pb2 + + expected_stats_pb = explain_stats_pb2.ExplainStats(data={}) + stats = ExplainStats(expected_stats_pb) + with pytest.raises(QueryExplainError) as exc: + stats.get_text() + assert "Unable to decode explain stats" in str(exc.value) From 18dfc6a850a570700e764f29218f2b7a697a17a4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 16 Dec 2025 08:46:01 -0800 Subject: [PATCH 15/27] chore: update docstring for pipelines array (#1129) Use new object syntax to reference pipeline array type in example --- google/cloud/firestore_v1/pipeline_expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 439b224cc..7e86ef6eb 
100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1850,7 +1850,7 @@ class Array(Function): Creates an expression that creates a Firestore array value from an input list. Example: - >>> Expression.array(["bar", Field.of("baz")]) + >>> Array(["bar", Field.of("baz")]) Args: elements: The input list to evaluate in the expression From 8026ced1a69ed178dd775d1262c85be2c7b730eb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 9 Jan 2026 13:53:51 -0800 Subject: [PATCH 16/27] chore: import main back into pipeline_preview branch --- .github/.OwlBot.lock.yaml | 17 - .github/.OwlBot.yaml | 33 - .github/auto-approve.yml | 3 - .github/release-please.yml | 12 - .github/release-trigger.yml | 2 - .github/sync-repo-settings.yaml | 47 - .github/workflows/lint.yml | 2 +- .github/workflows/mypy.yml | 2 +- .github/workflows/system_emulated.yml | 2 +- .github/workflows/unittest.yml | 2 +- .kokoro/presubmit/system-3.14.cfg | 7 - .kokoro/presubmit/system-3.7.cfg | 7 - .kokoro/presubmit/system-3.9.cfg | 7 - .kokoro/presubmit/system.cfg | 2 +- .kokoro/samples/python3.14/common.cfg | 40 + .kokoro/samples/python3.14/continuous.cfg | 6 + .kokoro/samples/python3.14/periodic-head.cfg | 11 + .kokoro/samples/python3.14/periodic.cfg | 6 + .kokoro/samples/python3.14/presubmit.cfg | 6 + .../generator-input/.repo-metadata.json | 18 + .../generator-input/librarian.py | 82 +- .librarian/generator-input/noxfile.py | 584 ++++++++++ .librarian/generator-input/setup.py | 95 ++ .librarian/state.yaml | 49 + CHANGELOG.md | 16 + CONTRIBUTING.rst | 1 + README.rst | 2 +- docs/README.rst | 198 +++- google/cloud/firestore/gapic_version.py | 2 +- .../firestore_admin_v1/gapic_metadata.json | 15 + .../cloud/firestore_admin_v1/gapic_version.py | 4 +- .../services/firestore_admin/async_client.py | 137 +++ .../services/firestore_admin/client.py | 181 ++- .../firestore_admin/transports/base.py | 19 +- .../firestore_admin/transports/grpc.py | 52 +- .../transports/grpc_asyncio.py | 59 +- .../firestore_admin/transports/rest.py | 221 +++- .../firestore_admin/transports/rest_base.py | 57 + .../firestore_admin_v1/types/__init__.py | 8 + .../firestore_admin_v1/types/database.py | 6 +- .../types/firestore_admin.py | 70 +- .../firestore_admin_v1/types/operation.py | 56 + .../firestore_admin_v1/types/snapshot.py | 67 ++ google/cloud/firestore_bundle/__init__.py | 104 ++ .../cloud/firestore_bundle/gapic_version.py | 4 +- google/cloud/firestore_v1/aggregation.py | 18 +- .../cloud/firestore_v1/async_aggregation.py | 12 +- google/cloud/firestore_v1/async_batch.py | 3 +- google/cloud/firestore_v1/async_client.py | 17 +- google/cloud/firestore_v1/async_collection.py | 5 +- google/cloud/firestore_v1/base_batch.py | 8 +- google/cloud/firestore_v1/base_client.py | 6 +- google/cloud/firestore_v1/base_collection.py | 3 +- google/cloud/firestore_v1/base_document.py | 6 +- google/cloud/firestore_v1/client.py | 7 +- google/cloud/firestore_v1/document.py | 2 +- google/cloud/firestore_v1/gapic_metadata.json | 15 + google/cloud/firestore_v1/gapic_version.py | 4 +- .../firestore_v1/services/firestore/client.py | 46 +- .../services/firestore/transports/base.py | 19 +- .../services/firestore/transports/grpc.py | 8 +- .../firestore/transports/grpc_asyncio.py | 22 +- .../services/firestore/transports/rest.py | 106 +- .../firestore/transports/rest_base.py | 57 + google/cloud/firestore_v1/stream_generator.py | 2 +- google/cloud/firestore_v1/types/document.py | 20 +- 
.../cloud/firestore_v1/types/explain_stats.py | 10 +- google/cloud/firestore_v1/types/firestore.py | 42 +- google/cloud/firestore_v1/types/pipeline.py | 2 +- google/cloud/firestore_v1/types/query.py | 176 +-- librarian.py | 118 ++ noxfile.py | 10 +- scripts/fixup_firestore_admin_v1_keywords.py | 212 ---- scripts/fixup_firestore_v1_keywords.py | 197 ---- setup.py | 8 +- .../test_firestore_admin.py | 1006 ++++++++++++++++- .../unit/gapic/firestore_v1/test_firestore.py | 257 ++++- tests/unit/gapic/v1/__init__.py | 0 tests/unit/v1/test_async_client.py | 3 + 79 files changed, 3806 insertions(+), 942 deletions(-) delete mode 100644 .github/.OwlBot.lock.yaml delete mode 100644 .github/.OwlBot.yaml delete mode 100644 .github/auto-approve.yml delete mode 100644 .github/release-please.yml delete mode 100644 .github/release-trigger.yml delete mode 100644 .github/sync-repo-settings.yaml delete mode 100644 .kokoro/presubmit/system-3.14.cfg delete mode 100644 .kokoro/presubmit/system-3.7.cfg delete mode 100644 .kokoro/presubmit/system-3.9.cfg create mode 100644 .kokoro/samples/python3.14/common.cfg create mode 100644 .kokoro/samples/python3.14/continuous.cfg create mode 100644 .kokoro/samples/python3.14/periodic-head.cfg create mode 100644 .kokoro/samples/python3.14/periodic.cfg create mode 100644 .kokoro/samples/python3.14/presubmit.cfg create mode 100644 .librarian/generator-input/.repo-metadata.json rename owlbot.py => .librarian/generator-input/librarian.py (56%) create mode 100644 .librarian/generator-input/noxfile.py create mode 100644 .librarian/generator-input/setup.py create mode 100644 .librarian/state.yaml mode change 120000 => 100644 docs/README.rst create mode 100644 google/cloud/firestore_admin_v1/types/snapshot.py create mode 100644 librarian.py delete mode 100644 scripts/fixup_firestore_admin_v1_keywords.py delete mode 100644 scripts/fixup_firestore_v1_keywords.py delete mode 100644 tests/unit/gapic/v1/__init__.py diff --git a/.github/.OwlBot.lock.yaml b/.github/.OwlBot.lock.yaml deleted file mode 100644 index 9a7846675..000000000 --- a/.github/.OwlBot.lock.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -docker: - image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - digest: sha256:4a9e5d44b98e8672e2037ee22bc6b4f8e844a2d75fcb78ea8a4b38510112abc6 -# created: 2025-10-07 diff --git a/.github/.OwlBot.yaml b/.github/.OwlBot.yaml deleted file mode 100644 index b720d256a..000000000 --- a/.github/.OwlBot.yaml +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2021 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -docker: - image: gcr.io/cloud-devrel-public-resources/owlbot-python:latest - -deep-remove-regex: - - /owl-bot-staging - -deep-preserve-regex: - - /owl-bot-staging/firestore/v1beta1 - -deep-copy-regex: - - source: /google/firestore/(v.*)/.*-py/(.*) - dest: /owl-bot-staging/firestore/$1/$2 - - source: /google/firestore/admin/(v.*)/.*-py/(.*) - dest: /owl-bot-staging/firestore_admin/$1/$2 - - source: /google/firestore/bundle/(.*-py)/(.*) - dest: /owl-bot-staging/firestore_bundle/$1/$2 - -begin-after-commit-hash: 107ed1217b5e87048263f52cd3911d5f851aca7e - diff --git a/.github/auto-approve.yml b/.github/auto-approve.yml deleted file mode 100644 index 311ebbb85..000000000 --- a/.github/auto-approve.yml +++ /dev/null @@ -1,3 +0,0 @@ -# https://github.com/googleapis/repo-automation-bots/tree/main/packages/auto-approve -processes: - - "OwlBotTemplateChanges" diff --git a/.github/release-please.yml b/.github/release-please.yml deleted file mode 100644 index fe749ff6b..000000000 --- a/.github/release-please.yml +++ /dev/null @@ -1,12 +0,0 @@ -releaseType: python -handleGHRelease: true -manifest: true -# NOTE: this section is generated by synthtool.languages.python -# See https://github.com/googleapis/synthtool/blob/master/synthtool/languages/python.py -branches: -- branch: v1 - handleGHRelease: true - releaseType: python -- branch: v0 - handleGHRelease: true - releaseType: python diff --git a/.github/release-trigger.yml b/.github/release-trigger.yml deleted file mode 100644 index 95896588a..000000000 --- a/.github/release-trigger.yml +++ /dev/null @@ -1,2 +0,0 @@ -enabled: true -multiScmName: python-firestore diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml deleted file mode 100644 index 389c3747c..000000000 --- a/.github/sync-repo-settings.yaml +++ /dev/null @@ -1,47 +0,0 @@ -# Rules for main branch protection -branchProtectionRules: -# Identifies the protection rule pattern. Name of the branch to be protected. -# Defaults to `main` -- pattern: main - # Can admins overwrite branch protection. - # Defaults to `true` - isAdminEnforced: true - # Number of approving reviews required to update matching branches. - # Defaults to `1` - requiredApprovingReviewCount: 1 - # Are reviews from code owners required to update matching branches. - # Defaults to `false` - requiresCodeOwnerReviews: true - # Require up to date branches - requiresStrictStatusChecks: true - # List of required status check contexts that must pass for commits to be accepted to matching branches. 
- requiredStatusCheckContexts: - - 'Kokoro' - - 'Kokoro system' - - 'cla/google' - - 'OwlBot Post Processor' - - 'docs' - - 'docfx' - - 'lint' - - 'unit (3.9)' - - 'unit (3.10)' - - 'unit (3.11)' - - 'unit (3.12)' - - 'unit (3.13)' - - 'unit (3.14)' - - 'cover' - - 'run-systests' -# List of explicit permissions to add (additive only) -permissionRules: - # Team slug to add to repository permissions - - team: yoshi-admins - # Access level required, one of push|pull|admin|maintain|triage - permission: admin - # Team slug to add to repository permissions - - team: yoshi-python-admins - # Access level required, one of push|pull|admin|maintain|triage - permission: admin - # Team slug to add to repository permissions - - team: yoshi-python - # Access level required, one of push|pull|admin|maintain|triage - permission: push diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9a0598202..3ed755f00 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -12,7 +12,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.13" + python-version: "3.14" - name: Install nox run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 27075146a..4997affc7 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -12,7 +12,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.13" + python-version: "3.14" - name: Install nox run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.github/workflows/system_emulated.yml b/.github/workflows/system_emulated.yml index bb7986a0a..62a879072 100644 --- a/.github/workflows/system_emulated.yml +++ b/.github/workflows/system_emulated.yml @@ -17,7 +17,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: '3.13' + python-version: '3.14' # firestore emulator requires java 21+ - name: Setup Java diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 494bb568f..cc6fe2b2f 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -45,7 +45,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: "3.13" + python-version: "3.14" - name: Install coverage run: | python -m pip install --upgrade setuptools pip wheel diff --git a/.kokoro/presubmit/system-3.14.cfg b/.kokoro/presubmit/system-3.14.cfg deleted file mode 100644 index 86e7c5d77..000000000 --- a/.kokoro/presubmit/system-3.14.cfg +++ /dev/null @@ -1,7 +0,0 @@ -# Format: //devtools/kokoro/config/proto/build.proto - -# Only run this nox session. -env_vars: { - key: "NOX_SESSION" - value: "system-3.14" -} \ No newline at end of file diff --git a/.kokoro/presubmit/system-3.7.cfg b/.kokoro/presubmit/system-3.7.cfg deleted file mode 100644 index 461537b3f..000000000 --- a/.kokoro/presubmit/system-3.7.cfg +++ /dev/null @@ -1,7 +0,0 @@ -# Format: //devtools/kokoro/config/proto/build.proto - -# Only run this nox session. -env_vars: { - key: "NOX_SESSION" - value: "system-3.7" -} \ No newline at end of file diff --git a/.kokoro/presubmit/system-3.9.cfg b/.kokoro/presubmit/system-3.9.cfg deleted file mode 100644 index b8ae66b37..000000000 --- a/.kokoro/presubmit/system-3.9.cfg +++ /dev/null @@ -1,7 +0,0 @@ -# Format: //devtools/kokoro/config/proto/build.proto - -# Only run this nox session. 
-env_vars: { - key: "NOX_SESSION" - value: "system-3.9" -} \ No newline at end of file diff --git a/.kokoro/presubmit/system.cfg b/.kokoro/presubmit/system.cfg index bd1fb514b..73904141b 100644 --- a/.kokoro/presubmit/system.cfg +++ b/.kokoro/presubmit/system.cfg @@ -3,5 +3,5 @@ # Only run this nox session. env_vars: { key: "NOX_SESSION" - value: "system-3.9" + value: "system-3.14" } diff --git a/.kokoro/samples/python3.14/common.cfg b/.kokoro/samples/python3.14/common.cfg new file mode 100644 index 000000000..4e07d3590 --- /dev/null +++ b/.kokoro/samples/python3.14/common.cfg @@ -0,0 +1,40 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +# Build logs will be here +action { + define_artifacts { + regex: "**/*sponge_log.xml" + } +} + +# Specify which tests to run +env_vars: { + key: "RUN_TESTS_SESSION" + value: "py-3.14" +} + +# Declare build specific Cloud project. +env_vars: { + key: "BUILD_SPECIFIC_GCLOUD_PROJECT" + value: "python-docs-samples-tests-314" +} + +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/python-firestore/.kokoro/test-samples.sh" +} + +# Configure the docker image for kokoro-trampoline. +env_vars: { + key: "TRAMPOLINE_IMAGE" + value: "gcr.io/cloud-devrel-kokoro-resources/python-samples-testing-docker" +} + +# Download secrets for samples +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/python-docs-samples" + +# Download trampoline resources. +gfile_resources: "/bigstore/cloud-devrel-kokoro-resources/trampoline" + +# Use the trampoline script to run in docker. +build_file: "python-firestore/.kokoro/trampoline_v2.sh" diff --git a/.kokoro/samples/python3.14/continuous.cfg b/.kokoro/samples/python3.14/continuous.cfg new file mode 100644 index 000000000..a1c8d9759 --- /dev/null +++ b/.kokoro/samples/python3.14/continuous.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.kokoro/samples/python3.14/periodic-head.cfg b/.kokoro/samples/python3.14/periodic-head.cfg new file mode 100644 index 000000000..21998d090 --- /dev/null +++ b/.kokoro/samples/python3.14/periodic-head.cfg @@ -0,0 +1,11 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} + +env_vars: { + key: "TRAMPOLINE_BUILD_FILE" + value: "github/python-firestore/.kokoro/test-samples-against-head.sh" +} diff --git a/.kokoro/samples/python3.14/periodic.cfg b/.kokoro/samples/python3.14/periodic.cfg new file mode 100644 index 000000000..71cd1e597 --- /dev/null +++ b/.kokoro/samples/python3.14/periodic.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "False" +} diff --git a/.kokoro/samples/python3.14/presubmit.cfg b/.kokoro/samples/python3.14/presubmit.cfg new file mode 100644 index 000000000..a1c8d9759 --- /dev/null +++ b/.kokoro/samples/python3.14/presubmit.cfg @@ -0,0 +1,6 @@ +# Format: //devtools/kokoro/config/proto/build.proto + +env_vars: { + key: "INSTALL_LIBRARY_FROM_SOURCE" + value: "True" +} \ No newline at end of file diff --git a/.librarian/generator-input/.repo-metadata.json b/.librarian/generator-input/.repo-metadata.json new file mode 100644 index 000000000..670bbc0e4 --- /dev/null +++ b/.librarian/generator-input/.repo-metadata.json @@ -0,0 +1,18 @@ +{ + "name": "firestore", + "name_pretty": "Cloud Firestore", + "product_documentation": "https://cloud.google.com/firestore", + 
"client_documentation": "https://cloud.google.com/python/docs/reference/firestore/latest", + "issue_tracker": "https://issuetracker.google.com/savedsearches/5337669", + "release_level": "stable", + "language": "python", + "library_type": "GAPIC_COMBO", + "repo": "googleapis/python-firestore", + "distribution_name": "google-cloud-firestore", + "api_id": "firestore.googleapis.com", + "requires_billing": true, + "default_version": "v1", + "codeowner_team": "@googleapis/api-firestore @googleapis/api-firestore-partners", + "api_shortname": "firestore", + "api_description": "is a fully-managed NoSQL document database for mobile, web, and server development from Firebase and Google Cloud Platform. It's backed by a multi-region replicated database that ensures once data is committed, it's durable even in the face of unexpected disasters. Not only that, but despite being a distributed database, it's also strongly consistent and offers seamless integration with other Firebase and Google Cloud Platform products, including Google Cloud Functions." +} diff --git a/owlbot.py b/.librarian/generator-input/librarian.py similarity index 56% rename from owlbot.py rename to .librarian/generator-input/librarian.py index a9323ce3c..ec92a9345 100644 --- a/owlbot.py +++ b/.librarian/generator-input/librarian.py @@ -28,50 +28,10 @@ firestore_default_version = "v1" firestore_admin_default_version = "v1" -# This is a customized version of the s.get_staging_dirs() function from synthtool to -# cater for copying 3 different folders from googleapis-gen -# which are firestore, firestore/admin and firestore/bundle. -# Source https://github.com/googleapis/synthtool/blob/master/synthtool/transforms.py#L280 -def get_staging_dirs( - default_version: Optional[str] = None, sub_directory: Optional[str] = None -) -> List[Path]: - """Returns the list of directories, one per version, copied from - https://github.com/googleapis/googleapis-gen. Will return in lexical sorting - order with the exception of the default_version which will be last (if specified). - - Args: - default_version (str): the default version of the API. The directory for this version - will be the last item in the returned list if specified. - sub_directory (str): if a `sub_directory` is provided, only the directories within the - specified `sub_directory` will be returned. - - Returns: the empty list if no file were copied. - """ - - staging = Path("owl-bot-staging") - - if sub_directory: - staging /= sub_directory - - if staging.is_dir(): - # Collect the subdirectories of the staging directory. - versions = [v.name for v in staging.iterdir() if v.is_dir()] - # Reorder the versions so the default version always comes last. 
- versions = [v for v in versions if v != default_version] - versions.sort() - if default_version is not None: - versions += [default_version] - dirs = [staging / v for v in versions] - for dir in dirs: - s._tracked_paths.add(dir) - return dirs - else: - return [] - -def update_fixup_scripts(library): +def update_fixup_scripts(path): # Add message for missing 'libcst' dependency s.replace( - library / "scripts/fixup*.py", + library / "scripts" / path, """import libcst as cst""", """try: import libcst as cst @@ -82,19 +42,21 @@ def update_fixup_scripts(library): """, ) -for library in get_staging_dirs(default_version=firestore_default_version, sub_directory="firestore"): - s.move(library / f"google/cloud/firestore_{library.name}", excludes=[f"__init__.py", "**/gapic_version.py", "noxfile.py"]) +for library in s.get_staging_dirs(default_version=firestore_default_version): + s.move(library / f"google/cloud/firestore_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) s.move(library / f"tests/", f"tests") - update_fixup_scripts(library) - s.move(library / "scripts") + fixup_script_path = "fixup_firestore_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) -for library in get_staging_dirs(default_version=firestore_admin_default_version, sub_directory="firestore_admin"): - s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=[f"__init__.py", "**/gapic_version.py", "noxfile.py"]) +for library in s.get_staging_dirs(default_version=firestore_admin_default_version): + s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=[f"__init__.py", "noxfile.py"]) s.move(library / f"tests", f"tests") - update_fixup_scripts(library) - s.move(library / "scripts") + fixup_script_path = "fixup_firestore_admin_v1_keywords.py" + update_fixup_scripts(fixup_script_path) + s.move(library / "scripts" / fixup_script_path) -for library in get_staging_dirs(sub_directory="firestore_bundle"): +for library in s.get_staging_dirs(): s.replace( library / "google/cloud/bundle/types/bundle.py", "from google.firestore.v1 import document_pb2 # type: ignore\n" @@ -127,7 +89,7 @@ def update_fixup_scripts(library): s.move( library / f"google/cloud/bundle", f"google/cloud/firestore_bundle", - excludes=["**/gapic_version.py", "noxfile.py"], + excludes=["noxfile.py"], ) s.move(library / f"tests", f"tests") @@ -143,24 +105,14 @@ def update_fixup_scripts(library): microgenerator=True, cov_level=100, split_system_tests=True, - default_python_version="3.13", - system_test_python_versions=["3.9", "3.14"], + default_python_version="3.14", + system_test_python_versions=["3.14"], unit_test_python_versions=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"], ) s.move(templated_files, - excludes=[".github/release-please.yml", "renovate.json"]) + excludes=[".github/**", ".kokoro/**", "renovate.json"]) python.py_samples(skip_readmes=True) s.shell.run(["nox", "-s", "blacken"], hide_output=False) - -s.replace( - ".kokoro/build.sh", - "# Setup service account credentials.", - """\ -# Setup firestore account credentials -export FIRESTORE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/firebase-credentials.json - -# Setup service account credentials.""", -) diff --git a/.librarian/generator-input/noxfile.py b/.librarian/generator-input/noxfile.py new file mode 100644 index 000000000..4fb209cbc --- /dev/null +++ b/.librarian/generator-input/noxfile.py @@ -0,0 +1,584 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2024 Google LLC +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by synthtool. DO NOT EDIT! + +from __future__ import absolute_import + +import os +import pathlib +import re +import shutil +from typing import Dict, List +import warnings + +import nox + +FLAKE8_VERSION = "flake8==6.1.0" +PYTYPE_VERSION = "pytype==2020.7.24" +BLACK_VERSION = "black[jupyter]==23.7.0" +ISORT_VERSION = "isort==5.11.0" +LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] + +DEFAULT_PYTHON_VERSION = "3.14" + +UNIT_TEST_PYTHON_VERSIONS: List[str] = [ + "3.7", + "3.8", + "3.9", + "3.10", + "3.11", + "3.12", + "3.13", + "3.14", +] +UNIT_TEST_STANDARD_DEPENDENCIES = [ + "mock", + "asyncmock", + "pytest", + "pytest-cov", + "pytest-asyncio==0.21.2", +] +UNIT_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ + "aiounittest", + "six", + "freezegun", +] +UNIT_TEST_LOCAL_DEPENDENCIES: List[str] = [] +UNIT_TEST_DEPENDENCIES: List[str] = [] +UNIT_TEST_EXTRAS: List[str] = [] +UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} + +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.14"] +SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ + "mock", + "pytest", + "google-cloud-testutils", +] +SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ + "pytest-asyncio==0.21.2", + "six", +] +SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_DEPENDENCIES: List[str] = [] +SYSTEM_TEST_EXTRAS: List[str] = [] +SYSTEM_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} + +CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() + +nox.options.sessions = [ + "unit-3.9", + "unit-3.10", + "unit-3.11", + "unit-3.12", + "unit-3.13", + "unit-3.14", + "system_emulated", + "system", + "mypy", + "cover", + "lint", + "lint_setup_py", + "blacken", + "docs", + "docfx", + "format", +] + +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or sufficiently + serious code quality issues. + """ + session.install(FLAKE8_VERSION, BLACK_VERSION) + session.run( + "black", + "--check", + *LINT_PATHS, + ) + session.run("flake8", "google", "tests") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def blacken(session): + """Run black. Format code to uniform standard.""" + session.install(BLACK_VERSION) + session.run( + "black", + *LINT_PATHS, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def format(session): + """ + Run isort to sort imports. Then run black + to format code to uniform standard. + """ + session.install(BLACK_VERSION, ISORT_VERSION) + # Use the --fss option to sort imports using strict alphabetical order. 
+ # See https://pycqa.github.io/isort/docs/configuration/options.html#force-sort-within-sections + session.run( + "isort", + "--fss", + *LINT_PATHS, + ) + session.run( + "black", + *LINT_PATHS, + ) + + +@nox.session(python="3.7") +def pytype(session): + """Verify type hints are pytype compatible.""" + session.install(PYTYPE_VERSION) + session.run( + "pytype", + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def mypy(session): + """Verify type hints are mypy compatible.""" + session.install("-e", ".") + session.install("mypy", "types-setuptools", "types-protobuf") + session.run( + "mypy", + "-p", + "google.cloud.firestore_v1", + "--no-incremental", + "--check-untyped-defs", + "--exclude", + "services", + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint_setup_py(session): + """Verify that setup.py is valid (including RST check).""" + session.install("setuptools", "docutils", "pygments") + session.run("python", "setup.py", "check", "--restructuredtext", "--strict") + + +def install_unittest_dependencies(session, *constraints): + standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES + session.install(*standard_deps, *constraints) + + if UNIT_TEST_EXTERNAL_DEPENDENCIES: + warnings.warn( + "'unit_test_external_dependencies' is deprecated. Instead, please " + "use 'unit_test_dependencies' or 'unit_test_local_dependencies'.", + DeprecationWarning, + ) + session.install(*UNIT_TEST_EXTERNAL_DEPENDENCIES, *constraints) + + if UNIT_TEST_LOCAL_DEPENDENCIES: + session.install(*UNIT_TEST_LOCAL_DEPENDENCIES, *constraints) + + if UNIT_TEST_EXTRAS_BY_PYTHON: + extras = UNIT_TEST_EXTRAS_BY_PYTHON.get(session.python, []) + elif UNIT_TEST_EXTRAS: + extras = UNIT_TEST_EXTRAS + else: + extras = [] + + if extras: + session.install("-e", f".[{','.join(extras)}]", *constraints) + else: + session.install("-e", ".", *constraints) + + +@nox.session(python=UNIT_TEST_PYTHON_VERSIONS) +@nox.parametrize( + "protobuf_implementation", + ["python", "upb", "cpp"], +) +def unit(session, protobuf_implementation): + # Install all test dependencies, then install this package in-place. + + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): + session.skip("cpp implementation is not supported in python 3.11+") + + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + install_unittest_dependencies(session, "-c", constraints_path) + + # TODO(https://github.com/googleapis/synthtool/issues/1976): + # Remove the 'cpp' implementation once support for Protobuf 3.x is dropped. + # The 'cpp' implementation requires Protobuf<4. + if protobuf_implementation == "cpp": + session.install("protobuf<4") + + # Run py.test against the unit tests. + session.run( + "py.test", + "--quiet", + f"--junitxml=unit_{session.python}_sponge_log.xml", + "--cov=google", + "--cov=tests/unit", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + os.path.join("tests", "unit"), + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + }, + ) + + +def install_systemtest_dependencies(session, *constraints): + # Use pre-release gRPC for system tests. + # Exclude version 1.52.0rc1 which has a known issue. 
+ # See https://github.com/grpc/grpc/issues/32163 + session.install("--pre", "grpcio!=1.52.0rc1") + + session.install(*SYSTEM_TEST_STANDARD_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_EXTERNAL_DEPENDENCIES: + session.install(*SYSTEM_TEST_EXTERNAL_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_LOCAL_DEPENDENCIES: + session.install("-e", *SYSTEM_TEST_LOCAL_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_DEPENDENCIES: + session.install("-e", *SYSTEM_TEST_DEPENDENCIES, *constraints) + + if SYSTEM_TEST_EXTRAS_BY_PYTHON: + extras = SYSTEM_TEST_EXTRAS_BY_PYTHON.get(session.python, []) + elif SYSTEM_TEST_EXTRAS: + extras = SYSTEM_TEST_EXTRAS + else: + extras = [] + + if extras: + session.install("-e", f".[{','.join(extras)}]", *constraints) + else: + session.install("-e", ".", *constraints) + + +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system_emulated(session): + import subprocess + import signal + + try: + # https://github.com/googleapis/python-firestore/issues/472 + # Kokoro image doesn't have java installed, don't attempt to run emulator. + subprocess.call(["java", "--version"]) + except OSError: + session.skip("java not found but required for emulator support") + + try: + subprocess.call(["gcloud", "--version"]) + except OSError: + session.skip("gcloud not found but required for emulator support") + + # Currently, CI/CD doesn't have beta component of gcloud. + subprocess.call( + [ + "gcloud", + "components", + "install", + "beta", + "cloud-firestore-emulator", + ] + ) + + hostport = "localhost:8789" + session.env["FIRESTORE_EMULATOR_HOST"] = hostport + + p = subprocess.Popen( + [ + "gcloud", + "--quiet", + "beta", + "emulators", + "firestore", + "start", + "--host-port", + hostport, + ] + ) + + try: + system(session) + finally: + # Stop Emulator + os.killpg(os.getpgid(p.pid), signal.SIGKILL) + + +@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS) +def system(session): + """Run the system test suite.""" + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + system_test_path = os.path.join("tests", "system.py") + system_test_folder_path = os.path.join("tests", "system") + + # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. + if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": + session.skip("RUN_SYSTEM_TESTS is set to false, skipping") + # Install pyopenssl for mTLS testing. + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": + session.install("pyopenssl") + + system_test_exists = os.path.exists(system_test_path) + system_test_folder_exists = os.path.exists(system_test_folder_path) + # Sanity check: only run tests if found. + if not system_test_exists and not system_test_folder_exists: + session.skip("System tests were not found") + + install_systemtest_dependencies(session, "-c", constraints_path) + + # Run py.test against the system tests. + if system_test_exists: + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_path, + *session.posargs, + ) + if system_test_folder_exists: + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_folder_path, + *session.posargs, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def cover(session): + """Run the final coverage report. + + This outputs the coverage report aggregating coverage from the unit + test runs (not system test runs), and then erases coverage data. 
+ """ + session.install("coverage", "pytest-cov") + session.run( + "coverage", + "report", + "--show-missing", + "--fail-under=100", + "--omit=tests/*", + ) + + session.run("coverage", "erase") + + +@nox.session(python="3.10") +def docs(session): + """Build the docs for this library.""" + + session.install("-e", ".") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "sphinx==4.5.0", + "alabaster", + "recommonmark", + ) + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-W", # warnings as errors + "-T", # show full traceback on exception + "-N", # no colors + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) + + +@nox.session(python="3.10") +def docfx(session): + """Build the docfx yaml files for this library.""" + + session.install("-e", ".") + session.install( + # We need to pin to specific versions of the `sphinxcontrib-*` packages + # which still support sphinx 4.x. + # See https://github.com/googleapis/sphinx-docfx-yaml/issues/344 + # and https://github.com/googleapis/sphinx-docfx-yaml/issues/345. + "sphinxcontrib-applehelp==1.0.4", + "sphinxcontrib-devhelp==1.0.2", + "sphinxcontrib-htmlhelp==2.0.1", + "sphinxcontrib-qthelp==1.0.3", + "sphinxcontrib-serializinghtml==1.1.5", + "gcp-sphinx-docfx-yaml", + "alabaster", + "recommonmark", + ) + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-T", # show full traceback on exception + "-N", # no colors + "-D", + ( + "extensions=sphinx.ext.autodoc," + "sphinx.ext.autosummary," + "docfx_yaml.extension," + "sphinx.ext.intersphinx," + "sphinx.ext.coverage," + "sphinx.ext.napoleon," + "sphinx.ext.todo," + "sphinx.ext.viewcode," + "recommonmark" + ), + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +@nox.parametrize( + "protobuf_implementation", + ["python", "upb", "cpp"], +) +def prerelease_deps(session, protobuf_implementation): + """Run all tests with prerelease versions of dependencies installed.""" + + py_version = tuple([int(v) for v in session.python.split(".")]) + if protobuf_implementation == "cpp" and py_version >= (3, 11): + session.skip("cpp implementation is not supported in python 3.11+") + + # Install all dependencies + session.install("-e", ".[all, tests, tracing]") + unit_deps_all = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_EXTERNAL_DEPENDENCIES + session.install(*unit_deps_all) + system_deps_all = ( + SYSTEM_TEST_STANDARD_DEPENDENCIES + SYSTEM_TEST_EXTERNAL_DEPENDENCIES + ) + session.install(*system_deps_all) + + # Because we test minimum dependency versions on the minimum Python + # version, the first version we test with in the unit tests sessions has a + # constraints file containing all dependencies and extras. 
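+    # The constraints file uses pip's requirements syntax, one exact pin per
+    # line; a minimal illustrative sketch of the shape the regex below matches
+    # (these package pins are examples, not the real file contents):
+    #
+    #   google-api-core==1.34.0
+    #   proto-plus==1.22.0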
+    with open(
+        CURRENT_DIRECTORY
+        / "testing"
+        / f"constraints-{UNIT_TEST_PYTHON_VERSIONS[0]}.txt",
+        encoding="utf-8",
+    ) as constraints_file:
+        constraints_text = constraints_file.read()
+
+    # Ignore leading whitespace and comment lines.
+    constraints_deps = [
+        match.group(1)
+        for match in re.finditer(
+            r"^\s*(\S+)(?===\S+)", constraints_text, flags=re.MULTILINE
+        )
+    ]
+
+    session.install(*constraints_deps)
+
+    prerel_deps = [
+        "protobuf",
+        # dependency of grpc
+        "six",
+        "grpc-google-iam-v1",
+        "googleapis-common-protos",
+        "grpcio",
+        "grpcio-status",
+        "google-api-core",
+        "google-auth",
+        "proto-plus",
+        "google-cloud-testutils",
+        # dependencies of google-cloud-testutils
+        "click",
+    ]
+
+    for dep in prerel_deps:
+        session.install("--pre", "--no-deps", "--upgrade", dep)
+
+    # Remaining dependencies
+    other_deps = [
+        "requests",
+    ]
+    session.install(*other_deps)
+
+    # Print out prerelease package versions
+    session.run(
+        "python", "-c", "import google.protobuf; print(google.protobuf.__version__)"
+    )
+    session.run("python", "-c", "import grpc; print(grpc.__version__)")
+    session.run("python", "-c", "import google.auth; print(google.auth.__version__)")
+
+    session.run(
+        "py.test",
+        "tests/unit",
+        env={
+            "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation,
+        },
+    )
+
+    system_test_path = os.path.join("tests", "system.py")
+    system_test_folder_path = os.path.join("tests", "system")
+
+    # Only run system tests if found.
+    if os.path.exists(system_test_path):
+        session.run(
+            "py.test",
+            "--verbose",
+            f"--junitxml=system_{session.python}_sponge_log.xml",
+            system_test_path,
+            *session.posargs,
+            env={
+                "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation,
+            },
+        )
+    if os.path.exists(system_test_folder_path):
+        session.run(
+            "py.test",
+            "--verbose",
+            f"--junitxml=system_{session.python}_sponge_log.xml",
+            system_test_folder_path,
+            *session.posargs,
+            env={
+                "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation,
+            },
+        )
diff --git a/.librarian/generator-input/setup.py b/.librarian/generator-input/setup.py
new file mode 100644
index 000000000..28d6faf51
--- /dev/null
+++ b/.librarian/generator-input/setup.py
@@ -0,0 +1,95 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import os
+
+import setuptools
+
+# Package metadata.
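+#
+# The version assigned below is exec-loaded from gapic_version.py, whose
+# entire content is the single assignment shown elsewhere in this patch
+# (quoted here for illustration):
+#
+#   __version__ = "2.22.0"  # {x-release-please-version}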
+ +name = "google-cloud-firestore" +description = "Google Cloud Firestore API client library" + +package_root = os.path.abspath(os.path.dirname(__file__)) + +version = {} +with open(os.path.join(package_root, "google/cloud/firestore/gapic_version.py")) as fp: + exec(fp.read(), version) +version = version["__version__"] +release_status = "Development Status :: 5 - Production/Stable" +dependencies = [ + "google-api-core[grpc] >= 1.34.0, <3.0.0,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*", + # Exclude incompatible versions of `google-auth` + # See https://github.com/googleapis/google-cloud-python/issues/12364 + "google-auth >= 2.14.1, <3.0.0,!=2.24.0,!=2.25.0", + "google-cloud-core >= 1.4.1, <3.0.0", + "proto-plus >= 1.22.0, <2.0.0", + "proto-plus >= 1.22.2, <2.0.0; python_version>='3.11'", + "proto-plus >= 1.25.0, <2.0.0; python_version>='3.13'", + "protobuf>=3.20.2,<7.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", +] +extras = {} + + +# Setup boilerplate below this line. + +package_root = os.path.abspath(os.path.dirname(__file__)) +readme_filename = os.path.join(package_root, "README.rst") +with io.open(readme_filename, encoding="utf-8") as readme_file: + readme = readme_file.read() + +# Only include packages under the 'google' namespace. Do not include tests, +# benchmarks, etc. +packages = [ + package + for package in setuptools.find_namespace_packages() + if package.startswith("google") +] + +setuptools.setup( + name=name, + version=version, + description=description, + long_description=readme, + author="Google LLC", + author_email="googleapis-packages@google.com", + license="Apache 2.0", + url="https://github.com/googleapis/python-firestore", + classifiers=[ + release_status, + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Operating System :: OS Independent", + "Topic :: Internet", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + platforms="Posix; MacOS X; Windows", + packages=packages, + install_requires=dependencies, + extras_require=extras, + python_requires=">=3.7", + include_package_data=True, + zip_safe=False, +) diff --git a/.librarian/state.yaml b/.librarian/state.yaml new file mode 100644 index 000000000..7715edb27 --- /dev/null +++ b/.librarian/state.yaml @@ -0,0 +1,49 @@ +image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator@sha256:b8058df4c45e9a6e07f6b4d65b458d0d059241dd34c814f151c8bf6b89211209 +libraries: + - id: google-cloud-firestore + version: 2.22.0 + last_generated_commit: 1a9d00bed77e6db82ff67764ffe14e3b5209f5cd + apis: + - path: google/firestore/v1 + service_config: firestore_v1.yaml + - path: google/firestore/admin/v1 + service_config: firestore_v1.yaml + - path: google/firestore/bundle + service_config: "" + source_roots: + - . 
+ preserve_regex: [] + remove_regex: + - ^google/cloud/firestore_v1/services + - ^google/cloud/firestore_v1/types + - ^google/cloud/firestore_v1/gapic + - ^google/cloud/firestore_v1/py.typed + - ^google/cloud/firestore_admin_v1/services + - ^google/cloud/firestore_admin_v1/types + - ^google/cloud/firestore_admin_v1/gapic + - ^google/cloud/firestore_admin_v1/py.typed + - ^google/cloud/firestore_bundle/services + - ^google/cloud/firestore_bundle/types + - ^google/cloud/firestore_bundle/__init__.py + - ^google/cloud/firestore_bundle/gapic + - ^google/cloud/firestore_bundle/py.typed + - ^tests/unit/gapic + - ^tests/__init__.py + - ^tests/unit/__init__.py + - ^.pre-commit-config.yaml + - ^.repo-metadata.json + - ^.trampolinerc + - ^.coveragerc + - ^SECURITY.md + - ^noxfile.py + - ^owlbot.py + - ^samples/AUTHORING_GUIDE.md + - ^samples/CONTRIBUTING.md + - ^samples/generated_samples + - ^scripts/fixup_firestore_v1_keywords.py + - ^scripts/fixup_firestore_admin_v1_keywords.py + - ^setup.py + - ^README.rst + - ^docs/README.rst + - ^docs/summary_overview.md + tag_format: v{version} diff --git a/CHANGELOG.md b/CHANGELOG.md index 893a01297..ee59f43a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,22 @@ [1]: https://pypi.org/project/google-cloud-firestore/#history +## [2.22.0](https://github.com/googleapis/python-firestore/compare/v2.21.0...v2.22.0) (2025-12-16) + + +### Features + +* support mTLS certificates when available (#1140) ([403afb08109c8271eddd97d6172136271cc0a8a9](https://github.com/googleapis/python-firestore/commit/403afb08109c8271eddd97d6172136271cc0a8a9)) +* Add support for Python 3.14 (#1110) ([52b2055d01ab5d2c34e00f8861e29990f89cd3d8](https://github.com/googleapis/python-firestore/commit/52b2055d01ab5d2c34e00f8861e29990f89cd3d8)) +* Expose tags field in Database and RestoreDatabaseRequest public protos (#1074) ([49836391dc712bd482781a26ccd3c8a8408c473b](https://github.com/googleapis/python-firestore/commit/49836391dc712bd482781a26ccd3c8a8408c473b)) +* Added read_time as a parameter to various calls (synchronous/base classes) (#1050) ([d8e3af1f9dbdfaf5df0d993a0a7e28883472c621](https://github.com/googleapis/python-firestore/commit/d8e3af1f9dbdfaf5df0d993a0a7e28883472c621)) + + +### Bug Fixes + +* improve typing (#1136) ([d1c730d9eef867d16d7818a75f7d58439a942c1d](https://github.com/googleapis/python-firestore/commit/d1c730d9eef867d16d7818a75f7d58439a942c1d)) +* update the async transactional types (#1066) ([210a14a4b758d70aad05940665ed2a2a21ae2a8b](https://github.com/googleapis/python-firestore/commit/210a14a4b758d70aad05940665ed2a2a21ae2a8b)) + ## [2.21.0](https://github.com/googleapis/python-firestore/compare/v2.20.2...v2.21.0) (2025-05-23) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index c91768524..b59294006 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -246,6 +246,7 @@ We support: .. _Python 3.10: https://docs.python.org/3.10/ .. _Python 3.11: https://docs.python.org/3.11/ .. _Python 3.12: https://docs.python.org/3.12/ +.. _Python 3.13: https://docs.python.org/3.13/ .. _Python 3.14: https://docs.python.org/3.14/ diff --git a/README.rst b/README.rst index e349bf783..71250f4f7 100644 --- a/README.rst +++ b/README.rst @@ -61,7 +61,7 @@ Supported Python Versions Our client libraries are compatible with all current `active`_ and `maintenance`_ versions of Python. -Python >= 3.7 +Python >= 3.7, including 3.14 .. _active: https://devguide.python.org/devcycle/#in-development-main-branch .. 
_maintenance: https://devguide.python.org/devcycle/#maintenance-branches diff --git a/docs/README.rst b/docs/README.rst deleted file mode 120000 index 89a010694..000000000 --- a/docs/README.rst +++ /dev/null @@ -1 +0,0 @@ -../README.rst \ No newline at end of file diff --git a/docs/README.rst b/docs/README.rst new file mode 100644 index 000000000..71250f4f7 --- /dev/null +++ b/docs/README.rst @@ -0,0 +1,197 @@ +Python Client for Cloud Firestore API +===================================== + +|stable| |pypi| |versions| + +`Cloud Firestore API`_: is a fully-managed NoSQL document database for mobile, web, and server development from Firebase and Google Cloud Platform. It's backed by a multi-region replicated database that ensures once data is committed, it's durable even in the face of unexpected disasters. Not only that, but despite being a distributed database, it's also strongly consistent and offers seamless integration with other Firebase and Google Cloud Platform products, including Google Cloud Functions. + +- `Client Library Documentation`_ +- `Product Documentation`_ + +.. |stable| image:: https://img.shields.io/badge/support-stable-gold.svg + :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#stability-levels +.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-firestore.svg + :target: https://pypi.org/project/google-cloud-firestore/ +.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-firestore.svg + :target: https://pypi.org/project/google-cloud-firestore/ +.. _Cloud Firestore API: https://cloud.google.com/firestore +.. _Client Library Documentation: https://cloud.google.com/python/docs/reference/firestore/latest/summary_overview +.. _Product Documentation: https://cloud.google.com/firestore + +Quick Start +----------- + +In order to use this library, you first need to go through the following steps: + +1. `Select or create a Cloud Platform project.`_ +2. `Enable billing for your project.`_ +3. `Enable the Cloud Firestore API.`_ +4. `Set up Authentication.`_ + +.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project +.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project +.. _Enable the Cloud Firestore API.: https://cloud.google.com/firestore +.. _Set up Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html + +Installation +~~~~~~~~~~~~ + +Install this library in a virtual environment using `venv`_. `venv`_ is a tool that +creates isolated Python environments. These isolated environments can have separate +versions of Python packages, which allows you to isolate one project's dependencies +from the dependencies of other projects. + +With `venv`_, it's possible to install this library without needing system +install permissions, and without clashing with the installed system +dependencies. + +.. _`venv`: https://docs.python.org/3/library/venv.html + + +Code samples and snippets +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Code samples and snippets live in the `samples/`_ folder. + +.. _samples/: https://github.com/googleapis/python-firestore/tree/main/samples + + +Supported Python Versions +^^^^^^^^^^^^^^^^^^^^^^^^^ +Our client libraries are compatible with all current `active`_ and `maintenance`_ versions of +Python. + +Python >= 3.7, including 3.14 + +.. _active: https://devguide.python.org/devcycle/#in-development-main-branch +.. 
_maintenance: https://devguide.python.org/devcycle/#maintenance-branches
+
+Unsupported Python Versions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Python <= 3.6
+
+If you are using an `end-of-life`_
+version of Python, we recommend that you update as soon as possible to an actively supported version.
+
+.. _end-of-life: https://devguide.python.org/devcycle/#end-of-life-branches
+
+Mac/Linux
+^^^^^^^^^
+
+.. code-block:: console
+
+    python3 -m venv <your-env>
+    source <your-env>/bin/activate
+    pip install google-cloud-firestore
+
+
+Windows
+^^^^^^^
+
+.. code-block:: console
+
+    py -m venv <your-env>
+    .\<your-env>\Scripts\activate
+    pip install google-cloud-firestore
+
+Next Steps
+~~~~~~~~~~
+
+- Read the `Client Library Documentation`_ for Cloud Firestore API
+  to see other available methods on the client.
+- Read the `Cloud Firestore API Product documentation`_ to learn
+  more about the product and see How-to Guides.
+- View this `README`_ to see the full list of Cloud
+  APIs that we cover.
+
+.. _Cloud Firestore API Product documentation: https://cloud.google.com/firestore
+.. _README: https://github.com/googleapis/google-cloud-python/blob/main/README.rst
+
+Logging
+-------
+
+This library uses the standard Python :code:`logging` functionality to log some RPC events that could be of interest for debugging and monitoring purposes.
+Note the following:
+
+#. Logs may contain sensitive information. Take care to **restrict access to the logs** if they are saved, whether it be on local storage or on Google Cloud Logging.
+#. Google may refine the occurrence, level, and content of various log messages in this library without flagging such changes as breaking. **Do not depend on immutability of the logging events**.
+#. By default, the logging events from this library are not handled. You must **explicitly configure log handling** using one of the mechanisms below.
+
+Simple, environment-based configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To enable logging for this library without any changes in your code, set the :code:`GOOGLE_SDK_PYTHON_LOGGING_SCOPE` environment variable to a valid Google
+logging scope. This configures handling of logging events (at level :code:`logging.DEBUG` or higher) from this library in a default manner, emitting the logged
+messages in a structured format. It does not currently allow customizing the logging levels captured nor the handlers, formatters, etc. used for any logging
+event.
+
+A logging scope is a period-separated namespace that begins with :code:`google`, identifying the Python module or package to log.
+
+- Valid logging scopes: :code:`google`, :code:`google.cloud.asset.v1`, :code:`google.api`, :code:`google.auth`, etc.
+- Invalid logging scopes: :code:`foo`, :code:`123`, etc.
+
+**NOTE**: If the logging scope is invalid, the library does not set up any logging handlers.
+
+Environment-Based Examples
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Enabling the default handler for all Google-based loggers
+
+.. code-block:: console
+
+   export GOOGLE_SDK_PYTHON_LOGGING_SCOPE=google
+
+- Enabling the default handler for a specific Google module (for a client library called :code:`library_v1`):
+
+.. code-block:: console
+
+   export GOOGLE_SDK_PYTHON_LOGGING_SCOPE=google.cloud.library_v1
+
+
+Advanced, code-based configuration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can also configure a valid logging scope using Python's standard `logging` mechanism.
+
+Code-Based Examples
+^^^^^^^^^^^^^^^^^^^
+
+- Configuring a handler for all Google-based loggers
+
+.. 
code-block:: python

+    import logging
+
+    from google.cloud import library_v1
+
+    base_logger = logging.getLogger("google")
+    base_logger.addHandler(logging.StreamHandler())
+    base_logger.setLevel(logging.DEBUG)
+
+- Configuring a handler for a specific Google module (for a client library called :code:`library_v1`):
+
+.. code-block:: python
+
+    import logging
+
+    from google.cloud import library_v1
+
+    base_logger = logging.getLogger("google.cloud.library_v1")
+    base_logger.addHandler(logging.StreamHandler())
+    base_logger.setLevel(logging.DEBUG)
+
+Logging details
+~~~~~~~~~~~~~~~
+
+#. Regardless of which of the mechanisms above you use to configure logging for this library, by default logging events are not propagated up to the root
+   logger from the `google`-level logger. If you need the events to be propagated to the root logger, you must explicitly set
+   :code:`logging.getLogger("google").propagate = True` in your code.
+#. You can mix the different logging configurations above for different Google modules. For example, you may want to use a code-based logging configuration for
+   one library, but decide you need to also set up environment-based logging configuration for another library.
+
+   #. If you attempt to use both code-based and environment-based configuration for the same module, the environment-based configuration will be ineffectual
+      if the code-based configuration gets applied first.
+
+#. The Google-specific logging configurations (default handlers for environment-based configuration; not propagating logging events to the root logger) get
+   executed the first time *any* client library is instantiated in your application, and only if the affected loggers have not been previously configured.
+   (This is the reason for 2.i. above.)
diff --git a/google/cloud/firestore/gapic_version.py b/google/cloud/firestore/gapic_version.py
index e546bae05..03d6d0200 100644
--- a/google/cloud/firestore/gapic_version.py
+++ b/google/cloud/firestore/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "2.21.0"  # {x-release-please-version}
+__version__ = "2.22.0"  # {x-release-please-version}
diff --git a/google/cloud/firestore_admin_v1/gapic_metadata.json b/google/cloud/firestore_admin_v1/gapic_metadata.json
index e2c91bdb5..b8d4cb298 100644
--- a/google/cloud/firestore_admin_v1/gapic_metadata.json
+++ b/google/cloud/firestore_admin_v1/gapic_metadata.json
@@ -15,6 +15,11 @@
         "bulk_delete_documents"
       ]
     },
+    "CloneDatabase": {
+      "methods": [
+        "clone_database"
+      ]
+    },
     "CreateBackupSchedule": {
       "methods": [
         "create_backup_schedule"
       ]
     },
@@ -175,6 +180,11 @@
         "bulk_delete_documents"
       ]
     },
+    "CloneDatabase": {
+      "methods": [
+        "clone_database"
+      ]
+    },
     "CreateBackupSchedule": {
       "methods": [
         "create_backup_schedule"
       ]
     },
@@ -335,6 +345,11 @@
         "bulk_delete_documents"
       ]
     },
+    "CloneDatabase": {
+      "methods": [
+        "clone_database"
+      ]
+    },
     "CreateBackupSchedule": {
       "methods": [
         "create_backup_schedule"
diff --git a/google/cloud/firestore_admin_v1/gapic_version.py b/google/cloud/firestore_admin_v1/gapic_version.py
index e546bae05..ced4e0faf 100644
--- a/google/cloud/firestore_admin_v1/gapic_version.py
+++ b/google/cloud/firestore_admin_v1/gapic_version.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-# Copyright 2022 Google LLC
+# Copyright 2025 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py index 56531fa29..a2800e34e 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/async_client.py @@ -4111,6 +4111,143 @@ async def sample_delete_backup_schedule(): metadata=metadata, ) + async def clone_database( + self, + request: Optional[Union[firestore_admin.CloneDatabaseRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_admin_v1 + + async def sample_clone_database(): + # Create a client + client = firestore_admin_v1.FirestoreAdminAsyncClient() + + # Initialize request argument(s) + pitr_snapshot = firestore_admin_v1.PitrSnapshot() + pitr_snapshot.database = "database_value" + + request = firestore_admin_v1.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + pitr_snapshot=pitr_snapshot, + ) + + # Make the request + operation = client.clone_database(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.firestore_admin_v1.types.CloneDatabaseRequest, dict]]): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+            metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be
+                sent along with the request as metadata. Normally, each value must be of type `str`,
+                but for metadata keys ending with the suffix `-bin`, the corresponding values must
+                be of type `bytes`.
+
+        Returns:
+            google.api_core.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`google.cloud.firestore_admin_v1.types.Database`
+                A Cloud Firestore Database.
+
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, firestore_admin.CloneDatabaseRequest):
+            request = firestore_admin.CloneDatabaseRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._client._transport._wrapped_methods[
+            self._client._transport.clone_database
+        ]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.pitr_snapshot.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.pitr_snapshot.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._client._validate_universe_domain()
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            database.Database,
+            metadata_type=gfa_operation.CloneDatabaseMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
     async def list_operations(
         self,
         request: Optional[operations_pb2.ListOperationsRequest] = None,
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
index d05b82787..b55c157cf 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/client.py
@@ -198,6 +198,34 @@ def _get_default_mtls_endpoint(api_endpoint):
     _DEFAULT_ENDPOINT_TEMPLATE = "firestore.{UNIVERSE_DOMAIN}"
     _DEFAULT_UNIVERSE = "googleapis.com"
 
+    @staticmethod
+    def _use_client_cert_effective():
+        """Returns whether client certificate should be used for mTLS if the
+        google-auth version supports should_use_client_cert automatic mTLS enablement.
+
+        Alternatively, read from the GOOGLE_API_USE_CLIENT_CERTIFICATE env var.
+
+        Returns:
+            bool: whether client certificate should be used for mTLS
+        Raises:
+            ValueError: (If using a version of google-auth without should_use_client_cert and
+                GOOGLE_API_USE_CLIENT_CERTIFICATE is set to an unexpected value.)
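+
+        For example (an illustrative sketch of the env-var fallback path on an
+        older google-auth without ``should_use_client_cert``)::
+
+            os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"
+            assert FirestoreAdminClient._use_client_cert_effective() is True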
+ """ + # check if google-auth version supports should_use_client_cert for automatic mTLS enablement + if hasattr(mtls, "should_use_client_cert"): # pragma: NO COVER + return mtls.should_use_client_cert() + else: # pragma: NO COVER + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -555,12 +583,8 @@ def get_mtls_endpoint_and_cert_source( ) if client_options is None: client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_client_cert = FirestoreAdminClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" @@ -568,7 +592,7 @@ def get_mtls_endpoint_and_cert_source( # Figure out the client cert source to use. client_cert_source = None - if use_client_cert == "true": + if use_client_cert: if client_options.client_cert_source: client_cert_source = client_options.client_cert_source elif mtls.has_default_client_cert_source(): @@ -600,20 +624,14 @@ def _read_environment_variables(): google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT is not any of ["auto", "never", "always"]. """ - use_client_cert = os.getenv( - "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" - ).lower() + use_client_cert = FirestoreAdminClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + return use_client_cert, use_mtls_endpoint, universe_domain_env @staticmethod def _get_client_cert_source(provided_cert_source, use_cert_flag): @@ -4591,6 +4609,141 @@ def sample_delete_backup_schedule(): metadata=metadata, ) + def clone_database( + self, + request: Optional[Union[firestore_admin.CloneDatabaseRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gac_operation.Operation: + r"""Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. 
This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import firestore_admin_v1 + + def sample_clone_database(): + # Create a client + client = firestore_admin_v1.FirestoreAdminClient() + + # Initialize request argument(s) + pitr_snapshot = firestore_admin_v1.PitrSnapshot() + pitr_snapshot.database = "database_value" + + request = firestore_admin_v1.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + pitr_snapshot=pitr_snapshot, + ) + + # Make the request + operation = client.clone_database(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.firestore_admin_v1.types.CloneDatabaseRequest, dict]): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.firestore_admin_v1.types.Database` + A Cloud Firestore Database. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, firestore_admin.CloneDatabaseRequest): + request = firestore_admin.CloneDatabaseRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
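+        # Routing-header note (illustrative): for a request whose
+        # pitr_snapshot.database is "projects/my-proj/databases/my-db",
+        # the two regexes below produce
+        # {"project_id": "my-proj", "database_id": "my-db"} as routing metadata.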
+        rpc = self._transport._wrapped_methods[self._transport.clone_database]
+
+        header_params = {}
+
+        routing_param_regex = re.compile("^projects/(?P<project_id>[^/]+)(?:/.*)?$")
+        regex_match = routing_param_regex.match(request.pitr_snapshot.database)
+        if regex_match and regex_match.group("project_id"):
+            header_params["project_id"] = regex_match.group("project_id")
+
+        routing_param_regex = re.compile(
+            "^projects/[^/]+/databases/(?P<database_id>[^/]+)(?:/.*)?$"
+        )
+        regex_match = routing_param_regex.match(request.pitr_snapshot.database)
+        if regex_match and regex_match.group("database_id"):
+            header_params["database_id"] = regex_match.group("database_id")
+
+        if header_params:
+            metadata = tuple(metadata) + (
+                gapic_v1.routing_header.to_grpc_metadata(header_params),
+            )
+
+        # Validate the universe domain.
+        self._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Wrap the response in an operation future.
+        response = gac_operation.from_gapic(
+            response,
+            self._transport.operations_client,
+            database.Database,
+            metadata_type=gfa_operation.CloneDatabaseMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
     def __enter__(self) -> "FirestoreAdminClient":
         return self
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
index f290fcbfe..7d582d9b5 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/base.py
@@ -81,9 +81,10 @@ def __init__(
                 credentials identify the application to the service; if none
                 are specified, the client will attempt to ascertain the
                 credentials from the environment.
-            credentials_file (Optional[str]): A file with credentials that can
+            credentials_file (Optional[str]): Deprecated. A file with credentials that can
                 be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is mutually exclusive with credentials.
+                This argument is mutually exclusive with credentials. This argument will be
+                removed in the next major version of this library.
             scopes (Optional[Sequence[str]]): A list of scopes.
             quota_project_id (Optional[str]): An optional project to use for billing
                 and quota.
@@ -357,6 +358,11 @@ def _prep_wrapped_messages(self, client_info):
                 default_timeout=None,
                 client_info=client_info,
             ),
+            self.clone_database: gapic_v1.method.wrap_method(
+                self.clone_database,
+                default_timeout=120.0,
+                client_info=client_info,
+            ),
             self.cancel_operation: gapic_v1.method.wrap_method(
                 self.cancel_operation,
                 default_timeout=None,
@@ -688,6 +694,15 @@ def delete_backup_schedule(
     ]:
         raise NotImplementedError()
 
+    @property
+    def clone_database(
+        self,
+    ) -> Callable[
+        [firestore_admin.CloneDatabaseRequest],
+        Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]],
+    ]:
+        raise NotImplementedError()
+
     @property
     def list_operations(
         self,
diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
index c6e7824c2..f6531a190 100644
--- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
+++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc.py
@@ -192,9 +192,10 @@ def __init__(
                 are specified, the client will attempt to ascertain the
                 credentials from the environment.
This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if a ``channel`` instance is provided. channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): @@ -328,9 +329,10 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -1279,6 +1281,50 @@ def delete_backup_schedule( ) return self._stubs["delete_backup_schedule"] + @property + def clone_database( + self, + ) -> Callable[[firestore_admin.CloneDatabaseRequest], operations_pb2.Operation]: + r"""Return a callable for the clone database method over gRPC. + + Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. + + Returns: + Callable[[~.CloneDatabaseRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
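+        # Caching note (illustrative of the pattern below): the callable is
+        # built once, stored in self._stubs["clone_database"], and reused on
+        # every later property access, so all calls share one stub per channel.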
+ if "clone_database" not in self._stubs: + self._stubs["clone_database"] = self._logged_channel.unary_unary( + "/google.firestore.admin.v1.FirestoreAdmin/CloneDatabase", + request_serializer=firestore_admin.CloneDatabaseRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["clone_database"] + def close(self): self._logged_channel.close() diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py index 9dd9d6155..117707853 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/grpc_asyncio.py @@ -189,8 +189,9 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. + credentials_file (Optional[str]): Deprecated. A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -241,9 +242,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -1331,6 +1333,52 @@ def delete_backup_schedule( ) return self._stubs["delete_backup_schedule"] + @property + def clone_database( + self, + ) -> Callable[ + [firestore_admin.CloneDatabaseRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the clone database method over gRPC. + + Creates a new database by cloning an existing one. + + The new database must be in the same cloud region or + multi-region location as the existing database. This behaves + similar to + [FirestoreAdmin.CreateDatabase][google.firestore.admin.v1.FirestoreAdmin.CreateDatabase] + except instead of creating a new empty database, a new database + is created with the database type, index configuration, and + documents from an existing database. + + The [long-running operation][google.longrunning.Operation] can + be used to track the progress of the clone, with the Operation's + [metadata][google.longrunning.Operation.metadata] field type + being the + [CloneDatabaseMetadata][google.firestore.admin.v1.CloneDatabaseMetadata]. + The [response][google.longrunning.Operation.response] type is + the [Database][google.firestore.admin.v1.Database] if the clone + was successful. The new database is not readable or writeable + until the LRO has completed. 
+ + Returns: + Callable[[~.CloneDatabaseRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "clone_database" not in self._stubs: + self._stubs["clone_database"] = self._logged_channel.unary_unary( + "/google.firestore.admin.v1.FirestoreAdmin/CloneDatabase", + request_serializer=firestore_admin.CloneDatabaseRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["clone_database"] + def _prep_wrapped_messages(self, client_info): """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" self._wrapped_methods = { @@ -1544,6 +1592,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.clone_database: self._wrap_method( + self.clone_database, + default_timeout=120.0, + client_info=client_info, + ), self.cancel_operation: self._wrap_method( self.cancel_operation, default_timeout=None, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py index c96be2e32..41e819c87 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest.py @@ -97,6 +97,14 @@ def post_bulk_delete_documents(self, response): logging.log(f"Received response: {response}") return response + def pre_clone_database(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_clone_database(self, response): + logging.log(f"Received response: {response}") + return response + def pre_create_backup_schedule(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -376,6 +384,54 @@ def post_bulk_delete_documents_with_metadata( """ return response, metadata + def pre_clone_database( + self, + request: firestore_admin.CloneDatabaseRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + firestore_admin.CloneDatabaseRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for clone_database + + Override in a subclass to manipulate the request or metadata + before they are sent to the FirestoreAdmin server. + """ + return request, metadata + + def post_clone_database( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for clone_database + + DEPRECATED. Please use the `post_clone_database_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the FirestoreAdmin server but before + it is returned to user code. This `post_clone_database` interceptor runs + before the `post_clone_database_with_metadata` interceptor. + """ + return response + + def post_clone_database_with_metadata( + self, + response: operations_pb2.Operation, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[operations_pb2.Operation, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for clone_database + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the FirestoreAdmin server but before it is returned to user code. 
+ + We recommend only using this `post_clone_database_with_metadata` + interceptor in new development instead of the `post_clone_database` interceptor. + When both interceptors are used, this `post_clone_database_with_metadata` interceptor runs after the + `post_clone_database` interceptor. The (possibly modified) response returned by + `post_clone_database` will be passed to + `post_clone_database_with_metadata`. + """ + return response, metadata + def pre_create_backup_schedule( self, request: firestore_admin.CreateBackupScheduleRequest, @@ -1857,9 +1913,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. + This argument is ignored if ``channel`` is provided. This argument will be + removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client @@ -2115,6 +2172,158 @@ def __call__( ) return resp + class _CloneDatabase( + _BaseFirestoreAdminRestTransport._BaseCloneDatabase, FirestoreAdminRestStub + ): + def __hash__(self): + return hash("FirestoreAdminRestTransport.CloneDatabase") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: firestore_admin.CloneDatabaseRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the clone database method over HTTP. + + Args: + request (~.firestore_admin.CloneDatabaseRequest): + The request object. The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options = ( + _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_http_options() + ) + + request, metadata = self._interceptor.pre_clone_database(request, metadata) + transcoded_request = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_transcoded_request( + http_options, request + ) + + body = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.firestore.admin_v1.FirestoreAdminClient.CloneDatabase", + extra={ + "serviceName": "google.firestore.admin.v1.FirestoreAdmin", + "rpcName": "CloneDatabase", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FirestoreAdminRestTransport._CloneDatabase._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_clone_database(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_clone_database_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore.admin_v1.FirestoreAdminClient.clone_database", + extra={ + "serviceName": "google.firestore.admin.v1.FirestoreAdmin", + "rpcName": "CloneDatabase", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _CreateBackupSchedule( _BaseFirestoreAdminRestTransport._BaseCreateBackupSchedule, FirestoreAdminRestStub, @@ -6507,6 +6716,14 @@ def bulk_delete_documents( # In C++ this would require a dynamic_cast return self._BulkDeleteDocuments(self._session, self._host, self._interceptor) # type: ignore + @property + def clone_database( + self, + ) -> Callable[[firestore_admin.CloneDatabaseRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._CloneDatabase(self._session, self._host, self._interceptor) # type: ignore + @property def create_backup_schedule( self, diff --git a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py index 19a0c9856..56b6ce93f 100644 --- a/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py +++ b/google/cloud/firestore_admin_v1/services/firestore_admin/transports/rest_base.py @@ -156,6 +156,63 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseCloneDatabase: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{parent=projects/*}/databases:clone", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore_admin.CloneDatabaseRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFirestoreAdminRestTransport._BaseCloneDatabase._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseCreateBackupSchedule: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/firestore_admin_v1/types/__init__.py b/google/cloud/firestore_admin_v1/types/__init__.py index 249147d52..c76372e5d 100644 --- a/google/cloud/firestore_admin_v1/types/__init__.py +++ b/google/cloud/firestore_admin_v1/types/__init__.py @@ -25,6 +25,7 @@ from .firestore_admin import ( BulkDeleteDocumentsRequest, BulkDeleteDocumentsResponse, + CloneDatabaseRequest, CreateBackupScheduleRequest, CreateDatabaseMetadata, CreateDatabaseRequest, @@ -73,6 +74,7 @@ ) from .operation import ( BulkDeleteDocumentsMetadata, + CloneDatabaseMetadata, ExportDocumentsMetadata, ExportDocumentsResponse, FieldOperationMetadata, @@ -87,6 +89,9 @@ DailyRecurrence, WeeklyRecurrence, ) +from .snapshot import ( + PitrSnapshot, +) from .user_creds import ( UserCreds, ) @@ -97,6 +102,7 @@ "Field", "BulkDeleteDocumentsRequest", "BulkDeleteDocumentsResponse", + "CloneDatabaseRequest", "CreateBackupScheduleRequest", "CreateDatabaseMetadata", "CreateDatabaseRequest", @@ -139,6 +145,7 @@ "Index", "LocationMetadata", "BulkDeleteDocumentsMetadata", + "CloneDatabaseMetadata", "ExportDocumentsMetadata", "ExportDocumentsResponse", "FieldOperationMetadata", @@ -150,5 +157,6 @@ "BackupSchedule", "DailyRecurrence", "WeeklyRecurrence", + "PitrSnapshot", "UserCreds", ) diff --git 
a/google/cloud/firestore_admin_v1/types/database.py b/google/cloud/firestore_admin_v1/types/database.py index eafa21df1..f46bede62 100644 --- a/google/cloud/firestore_admin_v1/types/database.py +++ b/google/cloud/firestore_admin_v1/types/database.py @@ -213,9 +213,9 @@ class PointInTimeRecoveryEnablement(proto.Enum): Reads are supported on selected versions of the data from within the past 7 days: - - Reads against any timestamp within the past hour - - Reads against 1-minute snapshots beyond 1 hour and within - 7 days + - Reads against any timestamp within the past hour + - Reads against 1-minute snapshots beyond 1 hour and within + 7 days ``version_retention_period`` and ``earliest_version_time`` can be used to determine the supported versions. diff --git a/google/cloud/firestore_admin_v1/types/firestore_admin.py b/google/cloud/firestore_admin_v1/types/firestore_admin.py index a4b577b78..9ede35cac 100644 --- a/google/cloud/firestore_admin_v1/types/firestore_admin.py +++ b/google/cloud/firestore_admin_v1/types/firestore_admin.py @@ -24,6 +24,7 @@ from google.cloud.firestore_admin_v1.types import field as gfa_field from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import schedule +from google.cloud.firestore_admin_v1.types import snapshot from google.cloud.firestore_admin_v1.types import user_creds as gfa_user_creds from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -73,6 +74,7 @@ "ListBackupsResponse", "DeleteBackupRequest", "RestoreDatabaseRequest", + "CloneDatabaseRequest", }, ) @@ -951,7 +953,7 @@ class ListBackupsRequest(proto.Message): [Backup][google.firestore.admin.v1.Backup] are eligible for filtering: - - ``database_uid`` (supports ``=`` only) + - ``database_uid`` (supports ``=`` only) """ parent: str = proto.Field( @@ -1079,4 +1081,70 @@ class RestoreDatabaseRequest(proto.Message): ) +class CloneDatabaseRequest(proto.Message): + r"""The request message for + [FirestoreAdmin.CloneDatabase][google.firestore.admin.v1.FirestoreAdmin.CloneDatabase]. + + Attributes: + parent (str): + Required. The project to clone the database in. Format is + ``projects/{project_id}``. + database_id (str): + Required. The ID to use for the database, which will become + the final component of the database's resource name. This + database ID must not be associated with an existing + database. + + This value should be 4-63 characters. Valid characters are + /[a-z][0-9]-/ with first character a letter and the last a + letter or a number. Must not be UUID-like + /[0-9a-f]{8}(-[0-9a-f]{4}){3}-[0-9a-f]{12}/. + + "(default)" database ID is also valid. + pitr_snapshot (google.cloud.firestore_admin_v1.types.PitrSnapshot): + Required. Specification of the PITR data to + clone from. The source database must exist. + + The cloned database will be created in the same + location as the source database. + encryption_config (google.cloud.firestore_admin_v1.types.Database.EncryptionConfig): + Optional. Encryption configuration for the cloned database. + + If this field is not specified, the cloned database will use + the same encryption configuration as the source database, + namely + [use_source_encryption][google.firestore.admin.v1.Database.EncryptionConfig.use_source_encryption]. + tags (MutableMapping[str, str]): + Optional. Immutable. Tags to be bound to the cloned + database. + + The tags should be provided in the format of + ``tagKeys/{tag_key_id} -> tagValues/{tag_value_id}``. 
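Putting the new request message together with the long-running operation it starts: a sketch of a clone call under the assumption that the new types are re-exported at the package root as the `types/__init__.py` changes above suggest (project, database IDs, and the snapshot time are hypothetical):

.. code-block:: python

    import datetime

    from google.cloud import firestore_admin_v1

    client = firestore_admin_v1.FirestoreAdminClient()

    request = firestore_admin_v1.CloneDatabaseRequest(
        parent="projects/my-project",       # hypothetical project
        database_id="cloned-db",            # must not name an existing database
        pitr_snapshot=firestore_admin_v1.PitrSnapshot(
            database="projects/my-project/databases/(default)",
            # PITR snapshots are minute-aligned; an arbitrary recent minute here.
            snapshot_time=datetime.datetime(
                2025, 5, 8, 12, 30, tzinfo=datetime.timezone.utc
            ),
        ),
    )

    # clone_database returns a long-running operation; its metadata is the
    # CloneDatabaseMetadata message and its result is the new Database.
    operation = client.clone_database(request=request)
    database = operation.result()
    print(database.name)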
+ """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + database_id: str = proto.Field( + proto.STRING, + number=2, + ) + pitr_snapshot: snapshot.PitrSnapshot = proto.Field( + proto.MESSAGE, + number=6, + message=snapshot.PitrSnapshot, + ) + encryption_config: gfa_database.Database.EncryptionConfig = proto.Field( + proto.MESSAGE, + number=4, + message=gfa_database.Database.EncryptionConfig, + ) + tags: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=5, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_admin_v1/types/operation.py b/google/cloud/firestore_admin_v1/types/operation.py index c58f24273..c50455693 100644 --- a/google/cloud/firestore_admin_v1/types/operation.py +++ b/google/cloud/firestore_admin_v1/types/operation.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.firestore_admin_v1.types import index as gfa_index +from google.cloud.firestore_admin_v1.types import snapshot from google.protobuf import timestamp_pb2 # type: ignore @@ -34,6 +35,7 @@ "BulkDeleteDocumentsMetadata", "ExportDocumentsResponse", "RestoreDatabaseMetadata", + "CloneDatabaseMetadata", "Progress", }, ) @@ -558,6 +560,60 @@ class RestoreDatabaseMetadata(proto.Message): ) +class CloneDatabaseMetadata(proto.Message): + r"""Metadata for the [long-running + operation][google.longrunning.Operation] from the + [CloneDatabase][google.firestore.admin.v1.CloneDatabase] request. + + Attributes: + start_time (google.protobuf.timestamp_pb2.Timestamp): + The time the clone was started. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The time the clone finished, unset for + ongoing clones. + operation_state (google.cloud.firestore_admin_v1.types.OperationState): + The operation state of the clone. + database (str): + The name of the database being cloned to. + pitr_snapshot (google.cloud.firestore_admin_v1.types.PitrSnapshot): + The snapshot from which this database was + cloned. + progress_percentage (google.cloud.firestore_admin_v1.types.Progress): + How far along the clone is as an estimated + percentage of remaining time. + """ + + start_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + end_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=2, + message=timestamp_pb2.Timestamp, + ) + operation_state: "OperationState" = proto.Field( + proto.ENUM, + number=3, + enum="OperationState", + ) + database: str = proto.Field( + proto.STRING, + number=4, + ) + pitr_snapshot: snapshot.PitrSnapshot = proto.Field( + proto.MESSAGE, + number=7, + message=snapshot.PitrSnapshot, + ) + progress_percentage: "Progress" = proto.Field( + proto.MESSAGE, + number=6, + message="Progress", + ) + + class Progress(proto.Message): r"""Describes the progress of the operation. Unit of work is generic and must be interpreted based on where diff --git a/google/cloud/firestore_admin_v1/types/snapshot.py b/google/cloud/firestore_admin_v1/types/snapshot.py new file mode 100644 index 000000000..e56a125f5 --- /dev/null +++ b/google/cloud/firestore_admin_v1/types/snapshot.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.protobuf import timestamp_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.firestore.admin.v1", + manifest={ + "PitrSnapshot", + }, +) + + +class PitrSnapshot(proto.Message): + r"""A consistent snapshot of a database at a specific point in + time. A PITR (Point-in-time recovery) snapshot with previous + versions of a database's data is available for every minute up + to the associated database's data retention period. If the PITR + feature is enabled, the retention period is 7 days; otherwise, + it is one hour. + + Attributes: + database (str): + Required. The name of the database that this was a snapshot + of. Format: ``projects/{project}/databases/{database}``. + database_uid (bytes): + Output only. Public UUID of the database the + snapshot was associated with. + snapshot_time (google.protobuf.timestamp_pb2.Timestamp): + Required. Snapshot time of the database. + """ + + database: str = proto.Field( + proto.STRING, + number=1, + ) + database_uid: bytes = proto.Field( + proto.BYTES, + number=2, + ) + snapshot_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/firestore_bundle/__init__.py b/google/cloud/firestore_bundle/__init__.py index 1b6469437..30faafe58 100644 --- a/google/cloud/firestore_bundle/__init__.py +++ b/google/cloud/firestore_bundle/__init__.py @@ -15,8 +15,18 @@ # from google.cloud.firestore_bundle import gapic_version as package_version +import google.api_core as api_core +import sys + __version__ = package_version.__version__ +if sys.version_info >= (3, 8): # pragma: NO COVER + from importlib import metadata +else: # pragma: NO COVER + # TODO(https://github.com/googleapis/python-api-core/issues/835): Remove + # this code path once we drop support for Python 3.7 + import importlib_metadata as metadata + from .types.bundle import BundledDocumentMetadata from .types.bundle import BundledQuery @@ -26,6 +36,100 @@ from .bundle import FirestoreBundle +if hasattr(api_core, "check_python_version") and hasattr( + api_core, "check_dependency_versions" +): # pragma: NO COVER + api_core.check_python_version("google.cloud.bundle") # type: ignore + api_core.check_dependency_versions("google.cloud.bundle") # type: ignore +else: # pragma: NO COVER + # An older version of api_core is installed which does not define the + # functions above. We do equivalent checks manually. + try: + import warnings + import sys + + _py_version_str = sys.version.split()[0] + _package_label = "google.cloud.bundle" + if sys.version_info < (3, 9): + warnings.warn( + "You are using a non-supported Python version " + + f"({_py_version_str}). Google will not post any further " + + f"updates to {_package_label} supporting this Python version. 
" + + "Please upgrade to the latest Python version, or at " + + f"least to Python 3.9, and then update {_package_label}.", + FutureWarning, + ) + if sys.version_info[:2] == (3, 9): + warnings.warn( + f"You are using a Python version ({_py_version_str}) " + + f"which Google will stop supporting in {_package_label} in " + + "January 2026. Please " + + "upgrade to the latest Python version, or at " + + "least to Python 3.10, before then, and " + + f"then update {_package_label}.", + FutureWarning, + ) + + def parse_version_to_tuple(version_string: str): + """Safely converts a semantic version string to a comparable tuple of integers. + Example: "4.25.8" -> (4, 25, 8) + Ignores non-numeric parts and handles common version formats. + Args: + version_string: Version string in the format "x.y.z" or "x.y.z" + Returns: + Tuple of integers for the parsed version string. + """ + parts = [] + for part in version_string.split("."): + try: + parts.append(int(part)) + except ValueError: + # If it's a non-numeric part (e.g., '1.0.0b1' -> 'b1'), stop here. + # This is a simplification compared to 'packaging.parse_version', but sufficient + # for comparing strictly numeric semantic versions. + break + return tuple(parts) + + def _get_version(dependency_name): + try: + version_string: str = metadata.version(dependency_name) + parsed_version = parse_version_to_tuple(version_string) + return (parsed_version, version_string) + except Exception: + # Catch exceptions from metadata.version() (e.g., PackageNotFoundError) + # or errors during parse_version_to_tuple + return (None, "--") + + _dependency_package = "google.protobuf" + _next_supported_version = "4.25.8" + _next_supported_version_tuple = (4, 25, 8) + _recommendation = " (we recommend 6.x)" + (_version_used, _version_used_string) = _get_version(_dependency_package) + if _version_used and _version_used < _next_supported_version_tuple: + warnings.warn( + f"Package {_package_label} depends on " + + f"{_dependency_package}, currently installed at version " + + f"{_version_used_string}. Future updates to " + + f"{_package_label} will require {_dependency_package} at " + + f"version {_next_supported_version} or higher{_recommendation}." + + " Please ensure " + + "that either (a) your Python environment doesn't pin the " + + f"version of {_dependency_package}, so that updates to " + + f"{_package_label} can require the higher version, or " + + "(b) you manually update your Python environment to use at " + + f"least version {_next_supported_version} of " + + f"{_dependency_package}.", + FutureWarning, + ) + except Exception: + warnings.warn( + "Could not determine the version of Python " + + "currently being used. To continue receiving " + + "updates for {_package_label}, ensure you are " + + "using a supported version of Python; see " + + "https://devguide.python.org/versions/" + ) + __all__ = ( "BundleElement", "BundleMetadata", diff --git a/google/cloud/firestore_bundle/gapic_version.py b/google/cloud/firestore_bundle/gapic_version.py index e546bae05..ced4e0faf 100644 --- a/google/cloud/firestore_bundle/gapic_version.py +++ b/google/cloud/firestore_bundle/gapic_version.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/aggregation.py b/google/cloud/firestore_v1/aggregation.py index 4070cd22b..69c4dc6bd 100644 --- a/google/cloud/firestore_v1/aggregation.py +++ b/google/cloud/firestore_v1/aggregation.py @@ -67,8 +67,7 @@ def get( messages. Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -77,8 +76,7 @@ def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given @@ -155,16 +153,14 @@ def _make_stream( this method cannot be used (i.e. read-after-write is not allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. retry (Optional[google.api_core.retry.Retry]): Designation of what errors, if any, should be retried. Defaults to a system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given @@ -237,16 +233,14 @@ def stream( this method cannot be used (i.e. read-after-write is not allowed). Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. retry (Optional[google.api_core.retry.Retry]): Designation of what errors, if any, should be retried. Defaults to a system-specified policy. timeout (Optinal[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. 
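These docstring reflows do not change behavior; the `explain_options` flow they describe looks roughly like the following sketch (the collection name, filter, and alias are hypothetical):

.. code-block:: python

    from google.cloud import firestore
    from google.cloud.firestore_v1.base_query import FieldFilter
    from google.cloud.firestore_v1.query_profile import ExplainOptions

    client = firestore.Client()
    query = client.collection("cities").where(
        filter=FieldFilter("population", ">", 100000)
    )
    agg_query = query.count(alias="total")

    # With explain_options set, explain_metrics is exposed on the result object.
    results = agg_query.get(explain_options=ExplainOptions(analyze=True))
    print(list(results))
    print(results.explain_metrics)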
read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given diff --git a/google/cloud/firestore_v1/async_aggregation.py b/google/cloud/firestore_v1/async_aggregation.py index e273f514a..5825a06d8 100644 --- a/google/cloud/firestore_v1/async_aggregation.py +++ b/google/cloud/firestore_v1/async_aggregation.py @@ -63,8 +63,7 @@ async def get( This sends a ``RunAggregationQuery`` RPC and returns a list of aggregation results in the stream of ``RunAggregationQueryResponse`` messages. Args: - transaction - (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that this query will run in. If a ``transaction`` is used and it already has write operations added, this method cannot be used (i.e. read-after-write is not @@ -73,8 +72,7 @@ async def get( should be retried. Defaults to a system-specified policy. timeout (float): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given @@ -134,8 +132,7 @@ async def _make_stream( system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given @@ -200,8 +197,7 @@ def stream( system-specified policy. timeout (Optional[float]): The timeout for this request. Defaults to a system-specified value. - explain_options - (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): + explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given diff --git a/google/cloud/firestore_v1/async_batch.py b/google/cloud/firestore_v1/async_batch.py index 689753fe9..f74ccacea 100644 --- a/google/cloud/firestore_v1/async_batch.py +++ b/google/cloud/firestore_v1/async_batch.py @@ -19,6 +19,7 @@ from google.api_core import retry_async as retries from google.cloud.firestore_v1.base_batch import BaseWriteBatch +from google.cloud.firestore_v1.types.write import WriteResult class AsyncWriteBatch(BaseWriteBatch): @@ -40,7 +41,7 @@ async def commit( self, retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, - ) -> list: + ) -> list[WriteResult]: """Commit the changes accumulated in this batch. 
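The tightened `commit()` annotation makes the returned `WriteResult` list visible to type checkers. A small usage sketch (document paths and fields are hypothetical):

.. code-block:: python

    import asyncio

    from google.cloud import firestore

    async def main():
        client = firestore.AsyncClient()
        batch = client.batch()  # AsyncWriteBatch
        batch.create(client.document("users/alice"), {"name": "Alice"})
        batch.set(client.document("users/bob"), {"name": "Bob"}, merge=True)

        write_results = await batch.commit()  # list[WriteResult]
        for result in write_results:
            print(result.update_time)

    asyncio.run(main())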
Args: diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 3acbedc76..efc4a47c0 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -25,7 +25,15 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Iterable, + List, + Optional, + Union, +) from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -40,6 +48,7 @@ from google.cloud.firestore_v1.async_transaction import AsyncTransaction from google.cloud.firestore_v1.base_client import _parse_batch_get # type: ignore from google.cloud.firestore_v1.base_client import _CLIENT_INFO, BaseClient, _path_helper +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.services.firestore import ( async_client as firestore_client, @@ -412,7 +421,9 @@ def batch(self) -> AsyncWriteBatch: """ return AsyncWriteBatch(self) - def transaction(self, **kwargs) -> AsyncTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> AsyncTransaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction` for @@ -428,7 +439,7 @@ def transaction(self, **kwargs) -> AsyncTransaction: :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`: A transaction attached to this client. """ - return AsyncTransaction(self, **kwargs) + return AsyncTransaction(self, max_attempts=max_attempts, read_only=read_only) @property def _pipeline_cls(self): diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index cc99aa460..561111163 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -15,7 +15,7 @@ """Classes for representing collections for the Google Cloud Firestore API.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, cast from google.api_core import gapic_v1 from google.api_core import retry_async as retries @@ -153,7 +153,8 @@ def document(self, document_id: str | None = None) -> AsyncDocumentReference: :class:`~google.cloud.firestore_v1.document.async_document.AsyncDocumentReference`: The child document. 
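Replacing `**kwargs` with explicit `max_attempts` / `read_only` parameters means a misspelled option now fails immediately instead of being silently absorbed. For example (document path hypothetical):

.. code-block:: python

    import asyncio

    from google.cloud import firestore

    async def main():
        client = firestore.AsyncClient()
        # Both options are now named, type-checked parameters.
        transaction = client.transaction(max_attempts=3, read_only=True)

        @firestore.async_transactional
        async def read_profile(txn, ref):
            snapshot = await ref.get(transaction=txn)
            return snapshot.to_dict()

        print(await read_profile(transaction, client.document("users/alice")))

    asyncio.run(main())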
""" - return super(AsyncCollectionReference, self).document(document_id) + doc = super(AsyncCollectionReference, self).document(document_id) + return cast("AsyncDocumentReference", doc) async def list_documents( self, diff --git a/google/cloud/firestore_v1/base_batch.py b/google/cloud/firestore_v1/base_batch.py index b0d50f1f4..851c7849f 100644 --- a/google/cloud/firestore_v1/base_batch.py +++ b/google/cloud/firestore_v1/base_batch.py @@ -15,7 +15,7 @@ """Helpers for batch requests to the Google Cloud Firestore API.""" from __future__ import annotations import abc -from typing import Dict, Union +from typing import Any, Dict, Union # Types needed only for Type Hints from google.api_core import retry as retries @@ -67,7 +67,9 @@ def commit(self): write depend on the implementing class.""" raise NotImplementedError() - def create(self, reference: BaseDocumentReference, document_data: dict) -> None: + def create( + self, reference: BaseDocumentReference, document_data: dict[str, Any] + ) -> None: """Add a "change" to this batch to create a document. If the document given by ``reference`` already exists, then this @@ -120,7 +122,7 @@ def set( def update( self, reference: BaseDocumentReference, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, ) -> None: """Add a "change" to update a document. diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 8c8b9532d..4ba8a7c06 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -58,7 +58,7 @@ DocumentSnapshot, ) from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.base_transaction import BaseTransaction +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS, BaseTransaction from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client @@ -500,7 +500,9 @@ def collections( def batch(self) -> BaseWriteBatch: raise NotImplementedError - def transaction(self, **kwargs) -> BaseTransaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> BaseTransaction: raise NotImplementedError def pipeline(self) -> PipelineSource: diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index 567fe4d8a..070e54cc4 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -35,6 +35,7 @@ from google.api_core import retry as retries from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.base_document import BaseDocumentReference from google.cloud.firestore_v1.base_query import QueryType if TYPE_CHECKING: # pragma: NO COVER @@ -134,7 +135,7 @@ def _aggregation_query(self) -> BaseAggregationQuery: def _vector_query(self) -> BaseVectorQuery: raise NotImplementedError - def document(self, document_id: Optional[str] = None): + def document(self, document_id: Optional[str] = None) -> BaseDocumentReference: """Create a sub-document underneath the current collection. 
Args: diff --git a/google/cloud/firestore_v1/base_document.py b/google/cloud/firestore_v1/base_document.py index 517db20d3..fe6113bfc 100644 --- a/google/cloud/firestore_v1/base_document.py +++ b/google/cloud/firestore_v1/base_document.py @@ -418,7 +418,7 @@ def _client(self): return self._reference._client @property - def exists(self): + def exists(self) -> bool: """Existence flag. Indicates if the document existed at the time this snapshot @@ -430,7 +430,7 @@ def exists(self): return self._exists @property - def id(self): + def id(self) -> str: """The document identifier (within its collection). Returns: @@ -439,7 +439,7 @@ def id(self): return self._reference.id @property - def reference(self): + def reference(self) -> BaseDocumentReference: """Document reference corresponding to document that owns this data. Returns: diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index c23943b24..ba2ca176d 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -39,6 +39,7 @@ # Types needed only for Type Hints from google.cloud.firestore_v1.base_document import DocumentSnapshot +from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS from google.cloud.firestore_v1.batch import WriteBatch from google.cloud.firestore_v1.collection import CollectionReference from google.cloud.firestore_v1.document import DocumentReference @@ -393,7 +394,9 @@ def batch(self) -> WriteBatch: """ return WriteBatch(self) - def transaction(self, **kwargs) -> Transaction: + def transaction( + self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False + ) -> Transaction: """Get a transaction that uses this client. See :class:`~google.cloud.firestore_v1.transaction.Transaction` for @@ -409,7 +412,7 @@ def transaction(self, **kwargs) -> Transaction: :class:`~google.cloud.firestore_v1.transaction.Transaction`: A transaction attached to this client. 
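With `exists`, `id`, and `reference` now explicitly annotated, snapshot-handling code type-checks cleanly. A sketch combining them with the explicit sync `transaction()` signature (collection and field names hypothetical):

.. code-block:: python

    from google.cloud import firestore

    client = firestore.Client()
    transaction = client.transaction(max_attempts=5, read_only=False)

    @firestore.transactional
    def bump_counter(txn, ref):
        snapshot = ref.get(transaction=txn)
        # snapshot.exists -> bool, snapshot.id -> str,
        # snapshot.reference -> DocumentReference
        data = snapshot.to_dict() if snapshot.exists else {}
        txn.set(ref, {"count": data.get("count", 0) + 1})

    bump_counter(transaction, client.document("stats/global"))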
""" - return Transaction(self, **kwargs) + return Transaction(self, max_attempts=max_attempts, read_only=read_only) @property def _pipeline_cls(self): diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 4e0132e49..4bb6399a7 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -169,7 +169,7 @@ def set( def update( self, - field_updates: dict, + field_updates: dict[str, Any], option: _helpers.WriteOption | None = None, retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, timeout: float | None = None, diff --git a/google/cloud/firestore_v1/gapic_metadata.json b/google/cloud/firestore_v1/gapic_metadata.json index d0462f964..03a6e428b 100644 --- a/google/cloud/firestore_v1/gapic_metadata.json +++ b/google/cloud/firestore_v1/gapic_metadata.json @@ -40,6 +40,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -125,6 +130,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" @@ -210,6 +220,11 @@ "delete_document" ] }, + "ExecutePipeline": { + "methods": [ + "execute_pipeline" + ] + }, "GetDocument": { "methods": [ "get_document" diff --git a/google/cloud/firestore_v1/gapic_version.py b/google/cloud/firestore_v1/gapic_version.py index e546bae05..ced4e0faf 100644 --- a/google/cloud/firestore_v1/gapic_version.py +++ b/google/cloud/firestore_v1/gapic_version.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "2.21.0" # {x-release-please-version} +__version__ = "2.22.0" # {x-release-please-version} diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index 49ea18d2a..e362896af 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -169,6 +169,34 @@ def _get_default_mtls_endpoint(api_endpoint): _DEFAULT_ENDPOINT_TEMPLATE = "firestore.{UNIVERSE_DOMAIN}" _DEFAULT_UNIVERSE = "googleapis.com" + @staticmethod + def _use_client_cert_effective(): + """Returns whether client certificate should be used for mTLS if the + google-auth version supports should_use_client_cert automatic mTLS enablement. + + Alternatively, read from the GOOGLE_API_USE_CLIENT_CERTIFICATE env var. + + Returns: + bool: whether client certificate should be used for mTLS + Raises: + ValueError: (If using a version of google-auth without should_use_client_cert and + GOOGLE_API_USE_CLIENT_CERTIFICATE is set to an unexpected value.) 
+ """ + # check if google-auth version supports should_use_client_cert for automatic mTLS enablement + if hasattr(mtls, "should_use_client_cert"): # pragma: NO COVER + return mtls.should_use_client_cert() + else: # pragma: NO COVER + # if unsupported, fallback to reading from env var + use_client_cert_str = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + if use_client_cert_str not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be" + " either `true` or `false`" + ) + return use_client_cert_str == "true" + @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -334,12 +362,8 @@ def get_mtls_endpoint_and_cert_source( ) if client_options is None: client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_client_cert = FirestoreClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" @@ -347,7 +371,7 @@ def get_mtls_endpoint_and_cert_source( # Figure out the client cert source to use. client_cert_source = None - if use_client_cert == "true": + if use_client_cert: if client_options.client_cert_source: client_cert_source = client_options.client_cert_source elif mtls.has_default_client_cert_source(): @@ -379,20 +403,14 @@ def _read_environment_variables(): google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT is not any of ["auto", "never", "always"]. """ - use_client_cert = os.getenv( - "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" - ).lower() + use_client_cert = FirestoreClient._use_client_cert_effective() use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") - if use_client_cert not in ("true", "false"): - raise ValueError( - "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) if use_mtls_endpoint not in ("auto", "never", "always"): raise MutualTLSChannelError( "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + return use_client_cert, use_mtls_endpoint, universe_domain_env @staticmethod def _get_client_cert_source(provided_cert_source, use_cert_flag): diff --git a/google/cloud/firestore_v1/services/firestore/transports/base.py b/google/cloud/firestore_v1/services/firestore/transports/base.py index ffccd7f0d..905dded09 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/base.py @@ -75,9 +75,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. 
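The fallback branch above only runs against an older google-auth that lacks `mtls.should_use_client_cert`; under that assumption, its env-var contract can be exercised directly. A sketch that pokes at the private helper purely for illustration:

.. code-block:: python

    import os

    from google.cloud.firestore_v1.services.firestore.client import FirestoreClient

    # Assumes the legacy fallback path (no mtls.should_use_client_cert).
    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "True"  # case-insensitive
    assert FirestoreClient._use_client_cert_effective() is True

    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "maybe"
    try:
        FirestoreClient._use_client_cert_effective()
    except ValueError as exc:
        print(exc)  # must be either `true` or `false`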
- This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. @@ -292,7 +293,19 @@ def _prep_wrapped_messages(self, client_info): ), self.execute_pipeline: gapic_v1.method.wrap_method( self.execute_pipeline, - default_timeout=None, + default_retry=retries.Retry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.InternalServerError, + core_exceptions.ResourceExhausted, + core_exceptions.ServiceUnavailable, + ), + deadline=300.0, + ), + default_timeout=300.0, client_info=client_info, ), self.run_aggregation_query: gapic_v1.method.wrap_method( diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc.py b/google/cloud/firestore_v1/services/firestore/transports/grpc.py index 2a8f4caf9..f057d16e3 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc.py @@ -164,9 +164,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if a ``channel`` instance is provided. channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): @@ -299,9 +300,10 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. + This argument is mutually exclusive with credentials. This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. diff --git a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py index 8801dc45a..cf6006672 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py +++ b/google/cloud/firestore_v1/services/firestore/transports/grpc_asyncio.py @@ -161,8 +161,9 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. + credentials_file (Optional[str]): Deprecated. A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. 
This argument will be + removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -213,9 +214,10 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if a ``channel`` instance is provided. - credentials_file (Optional[str]): A file with credentials that can + credentials_file (Optional[str]): Deprecated. A file with credentials that can be loaded with :func:`google.auth.load_credentials_from_file`. This argument is ignored if a ``channel`` instance is provided. + This argument will be removed in the next major version of this library. scopes (Optional[Sequence[str]]): A optional list of scopes needed for this service. These are only used when credentials are not specified and are passed to :func:`google.auth.default`. @@ -992,7 +994,19 @@ def _prep_wrapped_messages(self, client_info): ), self.execute_pipeline: self._wrap_method( self.execute_pipeline, - default_timeout=None, + default_retry=retries.AsyncRetry( + initial=0.1, + maximum=60.0, + multiplier=1.3, + predicate=retries.if_exception_type( + core_exceptions.DeadlineExceeded, + core_exceptions.InternalServerError, + core_exceptions.ResourceExhausted, + core_exceptions.ServiceUnavailable, + ), + deadline=300.0, + ), + default_timeout=300.0, client_info=client_info, ), self.run_aggregation_query: self._wrap_method( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest.py b/google/cloud/firestore_v1/services/firestore/transports/rest.py index 121aa7386..07e2cdbca 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest.py @@ -1006,27 +1006,27 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. + credentials_file (Optional[str]): Deprecated. A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. This argument will be + removed in the next major version of this library. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. 
+ client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. """ # Run the base constructor # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. @@ -1180,6 +1180,22 @@ def __call__( resp, _ = self._interceptor.post_batch_get_documents_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.batch_get_documents", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "BatchGetDocuments", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _BatchWrite(_BaseFirestoreRestTransport._BaseBatchWrite, FirestoreRestStub): @@ -2052,6 +2068,22 @@ def __call__( resp, _ = self._interceptor.post_execute_pipeline_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.execute_pipeline", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "ExecutePipeline", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _GetDocument(_BaseFirestoreRestTransport._BaseGetDocument, FirestoreRestStub): @@ -2934,6 +2966,22 @@ def __call__( resp, _ = self._interceptor.post_run_aggregation_query_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.run_aggregation_query", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunAggregationQuery", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class _RunQuery(_BaseFirestoreRestTransport._BaseRunQuery, FirestoreRestStub): @@ -3064,6 +3112,22 @@ def __call__( resp, _ = self._interceptor.post_run_query_with_metadata( resp, response_metadata ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + http_response = { + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.firestore_v1.FirestoreClient.run_query", + extra={ + "serviceName": "google.firestore.v1.Firestore", + "rpcName": "RunQuery", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) return resp class 
_UpdateDocument( diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 721f0792f..7d0c52f94 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -476,6 +476,63 @@ def _get_query_params_json(transcoded_request): return query_params + class _BaseExecutePipeline: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = firestore.ExecutePipelineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseGetDocument: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/firestore_v1/stream_generator.py b/google/cloud/firestore_v1/stream_generator.py index 7e39a3fba..b4b585601 100644 --- a/google/cloud/firestore_v1/stream_generator.py +++ b/google/cloud/firestore_v1/stream_generator.py @@ -50,7 +50,7 @@ def __init__( self._explain_options = explain_options self._explain_metrics = None - def __iter__(self) -> StreamGenerator: + def __iter__(self) -> StreamGenerator[T]: return self def __next__(self) -> T: diff --git a/google/cloud/firestore_v1/types/document.py b/google/cloud/firestore_v1/types/document.py index 1757571b1..8073ad97a 100644 --- a/google/cloud/firestore_v1/types/document.py +++ b/google/cloud/firestore_v1/types/document.py @@ -74,7 +74,7 @@ class Document(proto.Message): may contain any character. Some characters, including :literal:`\``, must be escaped using a ``\``. For example, :literal:`\`x&y\`` represents ``x&y`` and - :literal:`\`bak\`tik\`` represents :literal:`bak`tik`. + :literal:`\`bak\\`tik\`` represents :literal:`bak`tik`. create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. The time at which the document was created. @@ -195,10 +195,10 @@ class Value(proto.Message): **Requires:** - - Must follow [field reference][FieldReference.field_path] - limitations. + - Must follow [field reference][FieldReference.field_path] + limitations. - - Not allowed to be used when writing documents. + - Not allowed to be used when writing documents. This field is a member of `oneof`_ ``value_type``. 
function_value (google.cloud.firestore_v1.types.Function): @@ -206,7 +206,7 @@ class Value(proto.Message): **Requires:** - - Not allowed to be used when writing documents. + - Not allowed to be used when writing documents. This field is a member of `oneof`_ ``value_type``. pipeline_value (google.cloud.firestore_v1.types.Pipeline): @@ -214,7 +214,7 @@ class Value(proto.Message): **Requires:** - - Not allowed to be used when writing documents. + - Not allowed to be used when writing documents. This field is a member of `oneof`_ ``value_type``. """ @@ -353,8 +353,8 @@ class Function(proto.Message): **Requires:** - - must be in snake case (lower case with underscore - separator). + - must be in snake case (lower case with underscore + separator). args (MutableSequence[google.cloud.firestore_v1.types.Value]): Optional. Ordered list of arguments the given function expects. @@ -417,8 +417,8 @@ class Stage(proto.Message): **Requires:** - - must be in snake case (lower case with underscore - separator). + - must be in snake case (lower case with underscore + separator). args (MutableSequence[google.cloud.firestore_v1.types.Value]): Optional. Ordered list of arguments the given stage expects. diff --git a/google/cloud/firestore_v1/types/explain_stats.py b/google/cloud/firestore_v1/types/explain_stats.py index 1fda228b6..b0f9421ba 100644 --- a/google/cloud/firestore_v1/types/explain_stats.py +++ b/google/cloud/firestore_v1/types/explain_stats.py @@ -31,16 +31,18 @@ class ExplainStats(proto.Message): - r"""Explain stats for an RPC request, includes both the optimized - plan and execution stats. + r"""Pipeline explain stats. + + Depending on the explain options in the original request, this + can contain the optimized plan and / or execution stats. Attributes: data (google.protobuf.any_pb2.Any): The format depends on the ``output_format`` options in the request. - The only option today is ``TEXT``, which is a - ``google.protobuf.StringValue``. + Currently there are two supported options: ``TEXT`` and + ``JSON``. Both supply a ``google.protobuf.StringValue``. """ data: any_pb2.Any = proto.Field( diff --git a/google/cloud/firestore_v1/types/firestore.py b/google/cloud/firestore_v1/types/firestore.py index f1753c92f..4e53ba313 100644 --- a/google/cloud/firestore_v1/types/firestore.py +++ b/google/cloud/firestore_v1/types/firestore.py @@ -932,21 +932,23 @@ class ExecutePipelineResponse(proto.Message): only a partial progress message is returned. The fields present in the returned documents are only those - that were explicitly requested in the pipeline, this include - those like [``__name__``][google.firestore.v1.Document.name] - & + that were explicitly requested in the pipeline, this + includes those like + [``__name__``][google.firestore.v1.Document.name] and [``__update_time__``][google.firestore.v1.Document.update_time]. This is explicitly a divergence from ``Firestore.RunQuery`` / ``Firestore.GetDocument`` RPCs which always return such fields even when they are not specified in the [``mask``][google.firestore.v1.DocumentMask]. execution_time (google.protobuf.timestamp_pb2.Timestamp): - The time at which the document(s) were read. + The time at which the results are valid. - This may be monotonically increasing; in this case, the - previous documents in the result stream are guaranteed not - to have changed between their ``execution_time`` and this - one. + This is a (not strictly) monotonically increasing value + across multiple responses in the same stream. 
+            The API guarantees that all previously returned results are
+            still valid at the latest ``execution_time``. This allows
+            the API consumer to treat the query as if it ran at the
+            latest ``execution_time`` returned.
 
             If the query returns no results, a response with
             ``execution_time`` and no ``results`` will be sent, and this
@@ -954,9 +956,11 @@ class ExecutePipelineResponse(proto.Message):
         explain_stats (google.cloud.firestore_v1.types.ExplainStats):
             Query explain stats.
 
-            Contains all metadata related to pipeline
-            planning and execution, specific contents depend
-            on the supplied pipeline options.
+            This is present on the **last** response if the request
+            configured explain to run in 'analyze' or 'explain' mode in
+            the pipeline options. If the query does not return any
+            results, a response with ``explain_stats`` and no
+            ``results`` will still be sent.
     """
 
     transaction: bytes = proto.Field(
@@ -1162,8 +1166,8 @@ class PartitionQueryRequest(proto.Message):
             For example, two subsequent calls using a page_token may
             return:
 
-            -  cursor B, cursor M, cursor Q
-            -  cursor A, cursor U, cursor W
+            - cursor B, cursor M, cursor Q
+            - cursor A, cursor U, cursor W
 
             To obtain a complete result set ordered with respect to the
             results of the query supplied to PartitionQuery, the results
@@ -1237,9 +1241,9 @@ class PartitionQueryResponse(proto.Message):
             cursors A and B, running the following three queries will
             return the entire result set of the original query:
 
-            -  query, end_at A
-            -  query, start_at A, end_at B
-            -  query, start_at B
+            - query, end_at A
+            - query, start_at A, end_at B
+            - query, start_at B
 
             An empty result may indicate that the query has too few
            results to be partitioned, or that the query is not yet
@@ -1561,9 +1565,9 @@ class Target(proto.Message):
 
             Note that if the client sends multiple ``AddTarget`` requests
             without an ID, the order of IDs returned in
-            ``TargetChage.target_ids`` are undefined. Therefore, clients
-            should provide a target ID instead of relying on the server
-            to assign one.
+            ``TargetChange.target_ids`` is undefined. Therefore,
+            clients should provide a target ID instead of relying on the
+            server to assign one.
 
             If ``target_id`` is non-zero, there must not be an existing
             active target on this stream with the same ID.
diff --git a/google/cloud/firestore_v1/types/pipeline.py b/google/cloud/firestore_v1/types/pipeline.py
index 29fbe884b..07688dda7 100644
--- a/google/cloud/firestore_v1/types/pipeline.py
+++ b/google/cloud/firestore_v1/types/pipeline.py
@@ -34,7 +34,7 @@ class StructuredPipeline(proto.Message):
     r"""A Firestore query represented as an ordered list of operations /
     stages.
 
-    This is considered the top-level function which plans & executes a
+    This is considered the top-level function which plans and executes a
     query. It is logically equivalent to ``query(stages, options)``, but
     prevents the client from having to build a function wrapper.
 
diff --git a/google/cloud/firestore_v1/types/query.py b/google/cloud/firestore_v1/types/query.py
index 9aa8977dd..d50742785 100644
--- a/google/cloud/firestore_v1/types/query.py
+++ b/google/cloud/firestore_v1/types/query.py
@@ -66,25 +66,25 @@ class StructuredQuery(proto.Message):
             Firestore guarantees a stable ordering through the following
             rules:
 
-            -  The ``order_by`` is required to reference all fields used
-               with an inequality filter.
-            -  All fields that are required to be in the ``order_by``
-               but are not already present are appended in
-               lexicographical ordering of the field name.
- - If an order on ``__name__`` is not specified, it is - appended by default. + - The ``order_by`` is required to reference all fields used + with an inequality filter. + - All fields that are required to be in the ``order_by`` but + are not already present are appended in lexicographical + ordering of the field name. + - If an order on ``__name__`` is not specified, it is + appended by default. Fields are appended with the same sort direction as the last order specified, or 'ASCENDING' if no order was specified. For example: - - ``ORDER BY a`` becomes ``ORDER BY a ASC, __name__ ASC`` - - ``ORDER BY a DESC`` becomes - ``ORDER BY a DESC, __name__ DESC`` - - ``WHERE a > 1`` becomes - ``WHERE a > 1 ORDER BY a ASC, __name__ ASC`` - - ``WHERE __name__ > ... AND a > 1`` becomes - ``WHERE __name__ > ... AND a > 1 ORDER BY a ASC, __name__ ASC`` + - ``ORDER BY a`` becomes ``ORDER BY a ASC, __name__ ASC`` + - ``ORDER BY a DESC`` becomes + ``ORDER BY a DESC, __name__ DESC`` + - ``WHERE a > 1`` becomes + ``WHERE a > 1 ORDER BY a ASC, __name__ ASC`` + - ``WHERE __name__ > ... AND a > 1`` becomes + ``WHERE __name__ > ... AND a > 1 ORDER BY a ASC, __name__ ASC`` start_at (google.cloud.firestore_v1.types.Cursor): A potential prefix of a position in the result set to start the query at. @@ -106,10 +106,10 @@ class StructuredQuery(proto.Message): Continuing off the example above, attaching the following start cursors will have varying impact: - - ``START BEFORE (2, /k/123)``: start the query right - before ``a = 1 AND b > 2 AND __name__ > /k/123``. - - ``START AFTER (10)``: start the query right after - ``a = 1 AND b > 10``. + - ``START BEFORE (2, /k/123)``: start the query right before + ``a = 1 AND b > 2 AND __name__ > /k/123``. + - ``START AFTER (10)``: start the query right after + ``a = 1 AND b > 10``. Unlike ``OFFSET`` which requires scanning over the first N results to skip, a start cursor allows the query to begin at @@ -119,8 +119,8 @@ class StructuredQuery(proto.Message): Requires: - - The number of values cannot be greater than the number of - fields specified in the ``ORDER BY`` clause. + - The number of values cannot be greater than the number of + fields specified in the ``ORDER BY`` clause. end_at (google.cloud.firestore_v1.types.Cursor): A potential prefix of a position in the result set to end the query at. @@ -130,8 +130,8 @@ class StructuredQuery(proto.Message): Requires: - - The number of values cannot be greater than the number of - fields specified in the ``ORDER BY`` clause. + - The number of values cannot be greater than the number of + fields specified in the ``ORDER BY`` clause. offset (int): The number of documents to skip before returning the first result. @@ -142,8 +142,8 @@ class StructuredQuery(proto.Message): Requires: - - The value must be greater than or equal to zero if - specified. + - The value must be greater than or equal to zero if + specified. limit (google.protobuf.wrappers_pb2.Int32Value): The maximum number of results to return. @@ -151,8 +151,8 @@ class StructuredQuery(proto.Message): Requires: - - The value must be greater than or equal to zero if - specified. + - The value must be greater than or equal to zero if + specified. find_nearest (google.cloud.firestore_v1.types.StructuredQuery.FindNearest): Optional. A potential nearest neighbors search. @@ -256,7 +256,7 @@ class CompositeFilter(proto.Message): Requires: - - At least one filter is present. + - At least one filter is present. 
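+
+                For example (an illustrative sketch, not generated
+                code), an ``AND`` of two previously built filters:
+
+                .. code-block:: python
+
+                    from google.cloud.firestore_v1.types import query
+
+                    # ``filter_a`` and ``filter_b`` are assumed to be
+                    # existing StructuredQuery.Filter messages.
+                    composite = query.StructuredQuery.CompositeFilter(
+                        op=query.StructuredQuery.CompositeFilter.Operator.AND,
+                        filters=[filter_a, filter_b],
+                    )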
""" class Operator(proto.Enum): @@ -310,27 +310,27 @@ class Operator(proto.Enum): Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. LESS_THAN_OR_EQUAL (2): The given ``field`` is less than or equal to the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. GREATER_THAN (3): The given ``field`` is greater than the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. GREATER_THAN_OR_EQUAL (4): The given ``field`` is greater than or equal to the given ``value``. Requires: - - That ``field`` come first in ``order_by``. + - That ``field`` come first in ``order_by``. EQUAL (5): The given ``field`` is equal to the given ``value``. NOT_EQUAL (6): @@ -338,9 +338,9 @@ class Operator(proto.Enum): Requires: - - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. ARRAY_CONTAINS (7): The given ``field`` is an array that contains the given ``value``. @@ -350,31 +350,31 @@ class Operator(proto.Enum): Requires: - - That ``value`` is a non-empty ``ArrayValue``, subject to - disjunction limits. - - No ``NOT_IN`` filters in the same query. + - That ``value`` is a non-empty ``ArrayValue``, subject to + disjunction limits. + - No ``NOT_IN`` filters in the same query. ARRAY_CONTAINS_ANY (9): The given ``field`` is an array that contains any of the values in the given array. Requires: - - That ``value`` is a non-empty ``ArrayValue``, subject to - disjunction limits. - - No other ``ARRAY_CONTAINS_ANY`` filters within the same - disjunction. - - No ``NOT_IN`` filters in the same query. + - That ``value`` is a non-empty ``ArrayValue``, subject to + disjunction limits. + - No other ``ARRAY_CONTAINS_ANY`` filters within the same + disjunction. + - No ``NOT_IN`` filters in the same query. NOT_IN (10): The value of the ``field`` is not in the given array. Requires: - - That ``value`` is a non-empty ``ArrayValue`` with at most - 10 values. - - No other ``OR``, ``IN``, ``ARRAY_CONTAINS_ANY``, - ``NOT_IN``, ``NOT_EQUAL``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - That ``value`` is a non-empty ``ArrayValue`` with at most + 10 values. + - No other ``OR``, ``IN``, ``ARRAY_CONTAINS_ANY``, + ``NOT_IN``, ``NOT_EQUAL``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. """ OPERATOR_UNSPECIFIED = 0 LESS_THAN = 1 @@ -433,17 +433,17 @@ class Operator(proto.Enum): Requires: - - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - No other ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. IS_NOT_NULL (5): The given ``field`` is not equal to ``NULL``. Requires: - - A single ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or - ``IS_NOT_NAN``. - - That ``field`` comes first in the ``order_by``. + - A single ``NOT_EQUAL``, ``NOT_IN``, ``IS_NOT_NULL``, or + ``IS_NOT_NAN``. + - That ``field`` comes first in the ``order_by``. 
""" OPERATOR_UNSPECIFIED = 0 IS_NAN = 2 @@ -493,9 +493,9 @@ class FieldReference(proto.Message): Requires: - - MUST be a dot-delimited (``.``) string of segments, where - each segment conforms to [document field - name][google.firestore.v1.Document.fields] limitations. + - MUST be a dot-delimited (``.``) string of segments, where + each segment conforms to [document field + name][google.firestore.v1.Document.fields] limitations. """ field_path: str = proto.Field( @@ -555,9 +555,9 @@ class FindNearest(proto.Message): when the vectors are more similar, the comparison is inverted. - - For EUCLIDEAN, COSINE: WHERE distance <= - distance_threshold - - For DOT_PRODUCT: WHERE distance >= distance_threshold + - For EUCLIDEAN, COSINE: + ``WHERE distance <= distance_threshold`` + - For DOT_PRODUCT: ``WHERE distance >= distance_threshold`` """ class DistanceMeasure(proto.Enum): @@ -688,8 +688,8 @@ class StructuredAggregationQuery(proto.Message): Requires: - - A minimum of one and maximum of five aggregations per - query. + - A minimum of one and maximum of five aggregations per + query. """ class Aggregation(proto.Message): @@ -749,9 +749,9 @@ class Aggregation(proto.Message): Requires: - - Must be unique across all aggregation aliases. - - Conform to [document field - name][google.firestore.v1.Document.fields] limitations. + - Must be unique across all aggregation aliases. + - Conform to [document field + name][google.firestore.v1.Document.fields] limitations. """ class Count(proto.Message): @@ -778,7 +778,7 @@ class Count(proto.Message): Requires: - - Must be greater than zero when present. + - Must be greater than zero when present. """ up_to: wrappers_pb2.Int64Value = proto.Field( @@ -790,26 +790,26 @@ class Count(proto.Message): class Sum(proto.Message): r"""Sum of the values of the requested field. - - Only numeric values will be aggregated. All non-numeric values - including ``NULL`` are skipped. + - Only numeric values will be aggregated. All non-numeric values + including ``NULL`` are skipped. - - If the aggregated values contain ``NaN``, returns ``NaN``. - Infinity math follows IEEE-754 standards. + - If the aggregated values contain ``NaN``, returns ``NaN``. + Infinity math follows IEEE-754 standards. - - If the aggregated value set is empty, returns 0. + - If the aggregated value set is empty, returns 0. - - Returns a 64-bit integer if all aggregated numbers are integers - and the sum result does not overflow. Otherwise, the result is - returned as a double. Note that even if all the aggregated values - are integers, the result is returned as a double if it cannot fit - within a 64-bit signed integer. When this occurs, the returned - value will lose precision. + - Returns a 64-bit integer if all aggregated numbers are integers + and the sum result does not overflow. Otherwise, the result is + returned as a double. Note that even if all the aggregated values + are integers, the result is returned as a double if it cannot fit + within a 64-bit signed integer. When this occurs, the returned + value will lose precision. - - When underflow occurs, floating-point aggregation is - non-deterministic. This means that running the same query - repeatedly without any changes to the underlying values could - produce slightly different results each time. In those cases, - values should be stored as integers over floating-point numbers. + - When underflow occurs, floating-point aggregation is + non-deterministic. 
This means that running the same query + repeatedly without any changes to the underlying values could + produce slightly different results each time. In those cases, + values should be stored as integers over floating-point numbers. Attributes: field (google.cloud.firestore_v1.types.StructuredQuery.FieldReference): @@ -825,15 +825,15 @@ class Sum(proto.Message): class Avg(proto.Message): r"""Average of the values of the requested field. - - Only numeric values will be aggregated. All non-numeric values - including ``NULL`` are skipped. + - Only numeric values will be aggregated. All non-numeric values + including ``NULL`` are skipped. - - If the aggregated values contain ``NaN``, returns ``NaN``. - Infinity math follows IEEE-754 standards. + - If the aggregated values contain ``NaN``, returns ``NaN``. + Infinity math follows IEEE-754 standards. - - If the aggregated value set is empty, returns ``NULL``. + - If the aggregated value set is empty, returns ``NULL``. - - Always returns the result as a double. + - Always returns the result as a double. Attributes: field (google.cloud.firestore_v1.types.StructuredQuery.FieldReference): diff --git a/librarian.py b/librarian.py new file mode 100644 index 000000000..ec92a9345 --- /dev/null +++ b/librarian.py @@ -0,0 +1,118 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This script is used to synthesize generated parts of this library.""" +from pathlib import Path +from typing import List, Optional + +import synthtool as s +from synthtool import gcp +from synthtool.languages import python + +common = gcp.CommonTemplates() + +# This library ships clients for 3 different APIs, +# firestore, firestore_admin and firestore_bundle. 
+
+# firestore_bundle is not versioned
+firestore_default_version = "v1"
+firestore_admin_default_version = "v1"
+
+def update_fixup_scripts(library, path):
+    # Wrap the libcst import in the generated fixup script with a
+    # friendly message for the missing 'libcst' dependency.
+    # ``library`` is passed explicitly rather than read from the
+    # enclosing loop's variable.
+    s.replace(
+        library / "scripts" / path,
+        """import libcst as cst""",
+        """try:
+    import libcst as cst
+except ImportError:
+    raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.')
+
+
+    """,
+    )
+
+for library in s.get_staging_dirs(default_version=firestore_default_version):
+    s.move(library / f"google/cloud/firestore_{library.name}", excludes=["__init__.py", "noxfile.py"])
+    s.move(library / "tests", "tests")
+    fixup_script_path = "fixup_firestore_v1_keywords.py"
+    update_fixup_scripts(library, fixup_script_path)
+    s.move(library / "scripts" / fixup_script_path)
+
+for library in s.get_staging_dirs(default_version=firestore_admin_default_version):
+    s.move(library / f"google/cloud/firestore_admin_{library.name}", excludes=["__init__.py", "noxfile.py"])
+    s.move(library / "tests", "tests")
+    fixup_script_path = "fixup_firestore_admin_v1_keywords.py"
+    update_fixup_scripts(library, fixup_script_path)
+    s.move(library / "scripts" / fixup_script_path)
+
+for library in s.get_staging_dirs():
+    s.replace(
+        library / "google/cloud/bundle/types/bundle.py",
+        "from google.firestore.v1 import document_pb2  # type: ignore\n"
+        "from google.firestore.v1 import query_pb2  # type: ignore",
+        "from google.cloud.firestore_v1.types import document as document_pb2  # type: ignore\n"
+        "from google.cloud.firestore_v1.types import query as query_pb2  # type: ignore"
+    )
+
+    s.replace(
+        library / "google/cloud/bundle/__init__.py",
+        "from .types.bundle import BundleMetadata\n"
+        "from .types.bundle import NamedQuery\n",
+        "from .types.bundle import BundleMetadata\n"
+        "from .types.bundle import NamedQuery\n"
+        "\n"
+        "from .bundle import FirestoreBundle\n",
+    )
+
+    s.replace(
+        library / "google/cloud/bundle/__init__.py",
+        "from google.cloud.bundle import gapic_version as package_version\n",
+        "from google.cloud.firestore_bundle import gapic_version as package_version\n",
+    )
+
+    s.replace(
+        library / "google/cloud/bundle/__init__.py",
+        "\'BundledQuery\',",
+        "\"BundledQuery\",\n\"FirestoreBundle\",",)
+
+    s.move(
+        library / "google/cloud/bundle",
+        "google/cloud/firestore_bundle",
+        excludes=["noxfile.py"],
+    )
+    s.move(library / "tests", "tests")
+
+s.remove_staging_dirs()
+
+# ----------------------------------------------------------------------------
+# Add templated files
+# ----------------------------------------------------------------------------
+templated_files = common.py_library(
+    samples=False,  # set to True only if there are samples
+    unit_test_external_dependencies=["aiounittest", "six", "freezegun"],
+    system_test_external_dependencies=["pytest-asyncio", "six"],
+    microgenerator=True,
+    cov_level=100,
+    split_system_tests=True,
+    default_python_version="3.14",
+    system_test_python_versions=["3.14"],
+    unit_test_python_versions=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"],
+)
+
+s.move(templated_files,
+       excludes=[".github/**", ".kokoro/**", "renovate.json"])
+
+python.py_samples(skip_readmes=True)
+
+s.shell.run(["nox", "-s", "blacken"], hide_output=False)
diff --git a/noxfile.py b/noxfile.py
index 1e7039284..a1bce0822 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -14,6 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# DO NOT EDIT THIS FILE OUTSIDE OF `.librarian/generator-input` +# The source of truth for this file is `.librarian/generator-input` + + # Generated by synthtool. DO NOT EDIT! from __future__ import absolute_import @@ -33,7 +37,7 @@ ISORT_VERSION = "isort==5.11.0" LINT_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.13" +DEFAULT_PYTHON_VERSION = "3.14" UNIT_TEST_PYTHON_VERSIONS: List[str] = [ "3.7", @@ -62,7 +66,7 @@ UNIT_TEST_EXTRAS: List[str] = [] UNIT_TEST_EXTRAS_BY_PYTHON: Dict[str, List[str]] = {} -SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.9", "3.14"] +SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ["3.14"] SYSTEM_TEST_STANDARD_DEPENDENCIES: List[str] = [ "mock", "pytest", @@ -475,7 +479,7 @@ def docfx(session): ) -@nox.session(python="3.13") +@nox.session(python=DEFAULT_PYTHON_VERSION) @nox.parametrize( "protobuf_implementation", ["python", "upb", "cpp"], diff --git a/scripts/fixup_firestore_admin_v1_keywords.py b/scripts/fixup_firestore_admin_v1_keywords.py deleted file mode 100644 index 0920ce408..000000000 --- a/scripts/fixup_firestore_admin_v1_keywords.py +++ /dev/null @@ -1,212 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import argparse -import os -try: - import libcst as cst -except ImportError: - raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.') - - - -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class firestore_adminCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'bulk_delete_documents': ('name', 'collection_ids', 'namespace_ids', ), - 'create_backup_schedule': ('parent', 'backup_schedule', ), - 'create_database': ('parent', 'database', 'database_id', ), - 'create_index': ('parent', 'index', ), - 'create_user_creds': ('parent', 'user_creds', 'user_creds_id', ), - 'delete_backup': ('name', ), - 'delete_backup_schedule': ('name', ), - 'delete_database': ('name', 'etag', ), - 'delete_index': ('name', ), - 'delete_user_creds': ('name', ), - 'disable_user_creds': ('name', ), - 'enable_user_creds': ('name', ), - 'export_documents': ('name', 'collection_ids', 'output_uri_prefix', 'namespace_ids', 'snapshot_time', ), - 'get_backup': ('name', ), - 'get_backup_schedule': ('name', ), - 'get_database': ('name', ), - 'get_field': ('name', ), - 'get_index': ('name', ), - 'get_user_creds': ('name', ), - 'import_documents': ('name', 'collection_ids', 'input_uri_prefix', 'namespace_ids', ), - 'list_backups': ('parent', 'filter', ), - 'list_backup_schedules': ('parent', ), - 'list_databases': ('parent', 'show_deleted', ), - 'list_fields': ('parent', 'filter', 'page_size', 'page_token', ), - 'list_indexes': ('parent', 'filter', 'page_size', 'page_token', ), - 'list_user_creds': ('parent', ), - 'reset_user_password': ('name', ), - 'restore_database': ('parent', 'database_id', 'backup', 'encryption_config', 'tags', ), - 'update_backup_schedule': ('backup_schedule', 'update_mask', ), - 'update_database': ('database', 'update_mask', ), - 'update_field': ('field', 'update_mask', ), - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: a.keyword.value not in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), -cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. 
- for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=firestore_adminCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the firestore_admin client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/scripts/fixup_firestore_v1_keywords.py b/scripts/fixup_firestore_v1_keywords.py deleted file mode 100644 index 6481e76bb..000000000 --- a/scripts/fixup_firestore_v1_keywords.py +++ /dev/null @@ -1,197 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -import argparse -import os -try: - import libcst as cst -except ImportError: - raise ImportError('Run `python -m pip install "libcst >= 0.2.5"` to install libcst.') - - - -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class firestoreCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'batch_get_documents': ('database', 'documents', 'mask', 'transaction', 'new_transaction', 'read_time', ), - 'batch_write': ('database', 'writes', 'labels', ), - 'begin_transaction': ('database', 'options', ), - 'commit': ('database', 'writes', 'transaction', ), - 'create_document': ('parent', 'collection_id', 'document', 'document_id', 'mask', ), - 'delete_document': ('name', 'current_document', ), - 'get_document': ('name', 'mask', 'transaction', 'read_time', ), - 'list_collection_ids': ('parent', 'page_size', 'page_token', 'read_time', ), - 'list_documents': ('parent', 'collection_id', 'page_size', 'page_token', 'order_by', 'mask', 'transaction', 'read_time', 'show_missing', ), - 'listen': ('database', 'add_target', 'remove_target', 'labels', ), - 'partition_query': ('parent', 'structured_query', 'partition_count', 'page_token', 'page_size', 'read_time', ), - 'rollback': ('database', 'transaction', ), - 'run_aggregation_query': ('parent', 'structured_aggregation_query', 'transaction', 'new_transaction', 'read_time', 'explain_options', ), - 'run_query': ('parent', 'structured_query', 'transaction', 'new_transaction', 'read_time', 'explain_options', ), - 'update_document': ('document', 'update_mask', 'mask', 'current_document', ), - 'write': ('database', 'stream_id', 'writes', 'stream_token', 'labels', ), - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: a.keyword.value not in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), -cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. 
- for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=firestoreCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the firestore client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/setup.py b/setup.py index 8625abce9..72a6f53bd 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# DO NOT EDIT THIS FILE OUTSIDE OF `.librarian/generator-input` +# The source of truth for this file is `.librarian/generator-input` + + import io import os @@ -90,10 +94,6 @@ install_requires=dependencies, extras_require=extras, python_requires=">=3.7", - scripts=[ - "scripts/fixup_firestore_v1_keywords.py", - "scripts/fixup_firestore_admin_v1_keywords.py", - ], include_package_data=True, zip_safe=False, ) diff --git a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py index db8276d57..7ef138d15 100644 --- a/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py +++ b/tests/unit/gapic/firestore_admin_v1/test_firestore_admin.py @@ -75,6 +75,7 @@ from google.cloud.firestore_admin_v1.types import index as gfa_index from google.cloud.firestore_admin_v1.types import operation as gfa_operation from google.cloud.firestore_admin_v1.types import schedule +from google.cloud.firestore_admin_v1.types import snapshot from google.cloud.firestore_admin_v1.types import user_creds from google.cloud.firestore_admin_v1.types import user_creds as gfa_user_creds from google.cloud.location import locations_pb2 @@ -185,12 +186,19 @@ def test__read_environment_variables(): with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError) as excinfo: - FirestoreAdminClient._read_environment_variables() - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with pytest.raises(ValueError) as excinfo: + FirestoreAdminClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + else: + assert FirestoreAdminClient._read_environment_variables() == ( + False, + "auto", + None, + ) with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): assert FirestoreAdminClient._read_environment_variables() == ( @@ -229,6 +237,105 @@ def test__read_environment_variables(): ) +def test_use_client_cert_effective(): + # Test case 1: Test when `should_use_client_cert` returns True. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=True + ): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 2: Test when `should_use_client_cert` returns False. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should NOT be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=False + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 3: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true". 
+ if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 4: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 5: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "True". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "True"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 6: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "False". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "False"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 7: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "TRUE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}): + assert FirestoreAdminClient._use_client_cert_effective() is True + + # Test case 8: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "FALSE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "FALSE"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 9: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not set. + # In this case, the method should return False, which is the default value. + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, clear=True): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 10: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should raise a ValueError as the environment variable must be either + # "true" or "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + with pytest.raises(ValueError): + FirestoreAdminClient._use_client_cert_effective() + + # Test case 11: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should return False as the environment variable is set to an invalid value. 
+ if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + assert FirestoreAdminClient._use_client_cert_effective() is False + + # Test case 12: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is unset. Also, + # the GOOGLE_API_CONFIG environment variable is unset. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": ""}): + with mock.patch.dict(os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": ""}): + assert FirestoreAdminClient._use_client_cert_effective() is False + + def test__get_client_cert_source(): mock_provided_cert_source = mock.Mock() mock_default_cert_source = mock.Mock() @@ -594,17 +701,6 @@ def test_firestore_admin_client_client_options( == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client = client_class(transport=transport_name) - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: @@ -820,6 +916,119 @@ def test_firestore_admin_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == mock_api_endpoint assert cert_source is None + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "Unsupported". + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset. + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. + { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. 
+ { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None) + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset(empty). + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. + { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() @@ -870,18 +1079,6 @@ def test_firestore_admin_client_get_mtls_endpoint_and_cert_source(client_class): == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client_class.get_mtls_endpoint_and_cert_source() - - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - @pytest.mark.parametrize( "client_class", [FirestoreAdminClient, FirestoreAdminAsyncClient] @@ -11859,6 +12056,192 @@ async def test_delete_backup_schedule_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + firestore_admin.CloneDatabaseRequest, + dict, + ], +) +def test_clone_database(request_type, transport: str = "grpc"): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = firestore_admin.CloneDatabaseRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_clone_database_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = firestore_admin.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.clone_database(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == firestore_admin.CloneDatabaseRequest( + parent="parent_value", + database_id="database_id_value", + ) + + +def test_clone_database_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.clone_database in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[client._transport.clone_database] = mock_rpc + request = {} + client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_clone_database_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.clone_database + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.clone_database + ] = mock_rpc + + request = {} + await client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_clone_database_async( + transport: str = "grpc_asyncio", request_type=firestore_admin.CloneDatabaseRequest +): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = firestore_admin.CloneDatabaseRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_clone_database_async_from_dict(): + await test_clone_database_async(request_type=dict) + + def test_create_index_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call @@ -17625,31 +18008,166 @@ def test_delete_backup_schedule_rest_flattened_error(transport: str = "rest"): ) -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): +def test_clone_database_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: client = FirestoreAdminClient( credentials=ga_credentials.AnonymousCredentials(), - transport=transport, + transport="rest", ) - # It is an error to provide a credentials file and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = FirestoreAdminClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.clone_database in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. ) + client._transport._wrapped_methods[client._transport.clone_database] = mock_rpc - # It is an error to provide an api_key and a transport instance. - transport = transports.FirestoreAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) + request = {} + client.clone_database(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.clone_database(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_clone_database_rest_required_fields( + request_type=firestore_admin.CloneDatabaseRequest, +): + transport_class = transports.FirestoreAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["database_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).clone_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + jsonified_request["databaseId"] = "database_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).clone_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "databaseId" in jsonified_request + assert jsonified_request["databaseId"] == "database_id_value" + + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.clone_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_clone_database_rest_unset_required_fields(): + transport = transports.FirestoreAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.clone_database._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "databaseId", + "pitrSnapshot", + ) + ) + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FirestoreAdminClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.FirestoreAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) options = client_options.ClientOptions() options.api_key = "api_key" with pytest.raises(ValueError): @@ -18404,6 +18922,91 @@ def test_delete_backup_schedule_empty_call_grpc(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_clone_database_empty_call_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +def test_clone_database_routing_parameters_request_1_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. 
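The fake HTTP layer above simply serializes the expected proto with `json_format.MessageToJson` and hands it back as the response body; `$alt=json;enum-encoding=int` is pinned because the transport serializes enums as integers (`use_integers_for_enums=True`). A small round-trip sketch of that mechanism:

    import json

    from google.longrunning import operations_pb2
    from google.protobuf import json_format

    # Serialize the proto the way the mocked Session.request returns it.
    op = operations_pb2.Operation(name="operations/spam")
    body = json_format.MessageToJson(op)
    assert json.loads(body) == {"name": "operations/spam"}

    # Parsing the body restores an equal message, which is what lets the
    # mocked response stand in for a real server.
    parsed = json_format.Parse(body, operations_pb2.Operation())
    assert parsed == op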
+ call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_clone_database_routing_parameters_request_2_grpc(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_grpc_asyncio(): transport = FirestoreAdminAsyncClient.get_transport_class("grpc_asyncio")( credentials=async_anonymous_credentials() @@ -19270,6 +19873,103 @@ async def test_delete_backup_schedule_empty_call_grpc_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_clone_database_empty_call_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +@pytest.mark.asyncio +async def test_clone_database_routing_parameters_request_1_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. 
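The routing assertions in these tests look for the `x-goog-request-params` metadata entry that `to_grpc_metadata` builds from the IDs extracted out of the resource name. The helper in isolation (exact URL-encoding details aside):

    from google.api_core import gapic_v1

    key, value = gapic_v1.routing_header.to_grpc_metadata(
        {"project_id": "sample1", "database_id": "sample2"}
    )
    assert key == "x-goog-request-params"
    assert "project_id=sample1" in value and "database_id=sample2" in value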
+ call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +@pytest.mark.asyncio +async def test_clone_database_routing_parameters_request_2_grpc_asyncio(): + client = FirestoreAdminAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_transport_kind_rest(): transport = FirestoreAdminClient.get_transport_class("rest")( credentials=ga_credentials.AnonymousCredentials() @@ -23839,6 +24539,129 @@ def test_delete_backup_schedule_rest_interceptors(null_interceptor): pre.assert_called_once() +def test_clone_database_rest_bad_request( + request_type=firestore_admin.CloneDatabaseRequest, +): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.clone_database(request) + + +@pytest.mark.parametrize( + "request_type", + [ + firestore_admin.CloneDatabaseRequest, + dict, + ], +) +def test_clone_database_rest_call_success(request_type): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
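`FakeUnaryUnaryCall` wraps a plain value so the async client can await it. Outside of `api_core`'s test helpers, `unittest.mock.AsyncMock` gives equivalent awaitable behavior; `fake_rpc` below is a hypothetical stand-in:

    import asyncio
    from unittest import mock


    async def main():
        # AsyncMock returns an awaitable whose result is return_value.
        fake_rpc = mock.AsyncMock(return_value="operations/spam")
        result = await fake_rpc(request={})
        fake_rpc.assert_awaited_once()
        assert result == "operations/spam"


    asyncio.run(main())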
+ return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.clone_database(request) + + # Establish that the response is the type that we expect. + json_return_value = json_format.MessageToJson(return_value) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_clone_database_rest_interceptors(null_interceptor): + transport = transports.FirestoreAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FirestoreAdminRestInterceptor(), + ) + client = FirestoreAdminClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.FirestoreAdminRestInterceptor, "post_clone_database" + ) as post, mock.patch.object( + transports.FirestoreAdminRestInterceptor, "post_clone_database_with_metadata" + ) as post_with_metadata, mock.patch.object( + transports.FirestoreAdminRestInterceptor, "pre_clone_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = firestore_admin.CloneDatabaseRequest.pb( + firestore_admin.CloneDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.content = return_value + + request = firestore_admin.CloneDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata + + client.clone_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + def test_cancel_operation_rest_bad_request( request_type=operations_pb2.CancelOperationRequest, ): @@ -24736,6 +25559,88 @@ def test_delete_backup_schedule_empty_call_rest(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_clone_database_empty_call_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database(request=None) + + # Establish that the underlying stub method was called. 
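The interceptor test above patches three hooks on `FirestoreAdminRestInterceptor`. A sketch of a user-defined interceptor carrying those hooks (the authoritative signatures live on the generated base class; treat this shape as illustrative only):

    class LoggingInterceptor:  # would subclass transports.FirestoreAdminRestInterceptor
        def pre_clone_database(self, request, metadata):
            # Runs before the HTTP request; may rewrite the request or metadata.
            return request, metadata

        def post_clone_database(self, response):
            # Runs on the decoded response before it is returned to the caller.
            return response

        def post_clone_database_with_metadata(self, response, metadata):
            # Newer hook that also receives (and may rewrite) response metadata.
            return response, metadata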
+ call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest() + + assert args[0] == request_msg + + +def test_clone_database_routing_parameters_request_1_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database( + request={"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{"pitr_snapshot": {"database": "projects/sample1/sample2"}} + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + +def test_clone_database_routing_parameters_request_2_rest(): + client = FirestoreAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.clone_database), "__call__") as call: + client.clone_database( + request={ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, kw = call.mock_calls[0] + request_msg = firestore_admin.CloneDatabaseRequest( + **{ + "pitr_snapshot": { + "database": "projects/sample1/databases/sample2/sample3" + } + } + ) + + assert args[0] == request_msg + + expected_headers = {"project_id": "sample1", "database_id": "sample2"} + assert ( + gapic_v1.routing_header.to_grpc_metadata(expected_headers) in kw["metadata"] + ) + + def test_firestore_admin_rest_lro_client(): client = FirestoreAdminClient( credentials=ga_credentials.AnonymousCredentials(), @@ -24817,6 +25722,7 @@ def test_firestore_admin_base_transport(): "list_backup_schedules", "update_backup_schedule", "delete_backup_schedule", + "clone_database", "get_operation", "cancel_operation", "delete_operation", @@ -25189,6 +26095,9 @@ def test_firestore_admin_client_transport_session_collision(transport_name): session1 = client1.transport.delete_backup_schedule._session session2 = client2.transport.delete_backup_schedule._session assert session1 != session2 + session1 = client1.transport.clone_database._session + session2 = client2.transport.clone_database._session + assert session1 != session2 def test_firestore_admin_grpc_transport_channel(): @@ -25219,6 +26128,7 @@ def test_firestore_admin_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
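The next hunk adds a `FutureWarning` filter to this test. As a minimal illustration, the mark suppresses matching warnings for one test so they cannot fail a run where warnings are configured as errors:

    import warnings

    import pytest


    @pytest.mark.filterwarnings("ignore::FutureWarning")
    def test_deprecated_arguments_do_not_fail():
        # Without the mark, a `filterwarnings = error` config would turn
        # this FutureWarning into a test failure.
        warnings.warn("api_mtls_endpoint is deprecated", FutureWarning)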
+@pytest.mark.filterwarnings("ignore::FutureWarning") @pytest.mark.parametrize( "transport_class", [ diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index d91e91c96..af45e4326 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -162,12 +162,19 @@ def test__read_environment_variables(): with mock.patch.dict( os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} ): - with pytest.raises(ValueError) as excinfo: - FirestoreClient._read_environment_variables() - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with pytest.raises(ValueError) as excinfo: + FirestoreClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + else: + assert FirestoreClient._read_environment_variables() == ( + False, + "auto", + None, + ) with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): assert FirestoreClient._read_environment_variables() == (False, "never", None) @@ -194,6 +201,105 @@ def test__read_environment_variables(): ) +def test_use_client_cert_effective(): + # Test case 1: Test when `should_use_client_cert` returns True. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=True + ): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 2: Test when `should_use_client_cert` returns False. + # We mock the `should_use_client_cert` function to simulate a scenario where + # the google-auth library supports automatic mTLS and determines that a + # client certificate should NOT be used. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch( + "google.auth.transport.mtls.should_use_client_cert", return_value=False + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 3: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 4: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 5: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "True". 
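These cases pin down the strict true/false parsing applied when the library reads `GOOGLE_API_USE_CLIENT_CERTIFICATE` itself. An illustrative re-implementation of that parsing (`_parse_use_client_cert` is a hypothetical helper, not library code):

    import os


    def _parse_use_client_cert(environ=os.environ):
        value = environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false").lower()
        if value not in ("true", "false"):
            raise ValueError(
                "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` "
                "must be either `true` or `false`"
            )
        return value == "true"


    assert _parse_use_client_cert({"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}) is True
    assert _parse_use_client_cert({}) is False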
+ if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "True"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 6: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "False". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "False"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 7: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "TRUE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "TRUE"}): + assert FirestoreClient._use_client_cert_effective() is True + + # Test case 8: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "FALSE". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "FALSE"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 9: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not set. + # In this case, the method should return False, which is the default value. + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, clear=True): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 10: Test when `should_use_client_cert` is unavailable and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should raise a ValueError as the environment variable must be either + # "true" or "false". + if not hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + with pytest.raises(ValueError): + FirestoreClient._use_client_cert_effective() + + # Test case 11: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to an invalid value. + # The method should return False as the environment variable is set to an invalid value. + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "unsupported"} + ): + assert FirestoreClient._use_client_cert_effective() is False + + # Test case 12: Test when `should_use_client_cert` is available and the + # `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is unset. Also, + # the GOOGLE_API_CONFIG environment variable is unset. 
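All of these branches hinge on `hasattr` feature detection against the installed google-auth version. The pattern in isolation, assuming the no-argument form of `should_use_client_cert` that the tests mock:

    import os

    import google.auth.transport.mtls

    if hasattr(google.auth.transport.mtls, "should_use_client_cert"):
        # Newer google-auth: delegate the decision (env var, certificate
        # config, etc.) to the auth library.
        use_cert = google.auth.transport.mtls.should_use_client_cert()
    else:
        # Older google-auth: fall back to the raw environment variable.
        use_cert = (
            os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
        )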
+ if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": ""}): + with mock.patch.dict(os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": ""}): + assert FirestoreClient._use_client_cert_effective() is False + + def test__get_client_cert_source(): mock_provided_cert_source = mock.Mock() mock_default_cert_source = mock.Mock() @@ -557,17 +663,6 @@ def test_firestore_client_client_options(client_class, transport_class, transpor == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client = client_class(transport=transport_name) - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") with mock.patch.object(transport_class, "__init__") as patched: @@ -779,6 +874,119 @@ def test_firestore_client_get_mtls_endpoint_and_cert_source(client_class): assert api_endpoint == mock_api_endpoint assert cert_source is None + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "Unsupported". + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset. + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. + { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", None) + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + + # Test cases for mTLS enablement when GOOGLE_API_USE_CLIENT_CERTIFICATE is unset(empty). + test_cases = [ + ( + # With workloads present in config, mTLS is enabled. 
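The certificate-config cases never touch the filesystem: `mock.mock_open` feeds the JSON config to whatever code calls `builtins.open`. The technique in isolation:

    import json
    from unittest import mock

    config = {
        "version": 1,
        "cert_configs": {"workload": {"cert_path": "c", "key_path": "k"}},
    }
    m = mock.mock_open(read_data=json.dumps(config))
    with mock.patch("builtins.open", m):
        with open("mock_certificate_config.json") as f:
            loaded = json.load(f)
    assert "workload" in loaded["cert_configs"]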
+ { + "version": 1, + "cert_configs": { + "workload": { + "cert_path": "path/to/cert/file", + "key_path": "path/to/key/file", + } + }, + }, + mock_client_cert_source, + ), + ( + # With workloads not present in config, mTLS is disabled. + { + "version": 1, + "cert_configs": {}, + }, + None, + ), + ] + if hasattr(google.auth.transport.mtls, "should_use_client_cert"): + for config_data, expected_cert_source in test_cases: + env = os.environ.copy() + env.pop("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") + with mock.patch.dict(os.environ, env, clear=True): + config_filename = "mock_certificate_config.json" + config_file_content = json.dumps(config_data) + m = mock.mock_open(read_data=config_file_content) + with mock.patch("builtins.open", m): + with mock.patch.dict( + os.environ, {"GOOGLE_API_CERTIFICATE_CONFIG": config_filename} + ): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, + api_endpoint=mock_api_endpoint, + ) + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is expected_cert_source + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() @@ -829,18 +1037,6 @@ def test_firestore_client_get_mtls_endpoint_and_cert_source(client_class): == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" ) - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError) as excinfo: - client_class.get_mtls_endpoint_and_cert_source() - - assert ( - str(excinfo.value) - == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" - ) - @pytest.mark.parametrize("client_class", [FirestoreClient, FirestoreAsyncClient]) @mock.patch.object( @@ -7703,7 +7899,7 @@ def test_execute_pipeline_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.execute_pipeline(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -12590,6 +12786,7 @@ def test_firestore_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
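The two near-identical `test_cases` loops above differ only in how the environment variable is cleared; a parametrized variant could collapse them. A hypothetical skeleton of that refactor (the body would carry the same assertions as the loops):

    import pytest

    WORKLOAD_CONFIG = {
        "version": 1,
        "cert_configs": {"workload": {"cert_path": "c", "key_path": "k"}},
    }
    EMPTY_CONFIG = {"version": 1, "cert_configs": {}}


    @pytest.mark.parametrize(
        "config_data,expect_cert_source",
        [(WORKLOAD_CONFIG, True), (EMPTY_CONFIG, False)],
    )
    @pytest.mark.parametrize("env_value", [None, ""])
    def test_mtls_from_certificate_config(config_data, expect_cert_source, env_value):
        ...  # mock open() with config_data and assert cert_source as above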
+@pytest.mark.filterwarnings("ignore::FutureWarning") @pytest.mark.parametrize( "transport_class", [transports.FirestoreGrpcTransport, transports.FirestoreGrpcAsyncIOTransport], diff --git a/tests/unit/gapic/v1/__init__.py b/tests/unit/gapic/v1/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 210aae88d..3aeef8f9f 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -374,6 +374,9 @@ async def test_asyncclient_get_all_read_time(): @pytest.mark.asyncio +@pytest.mark.filterwarnings( + "ignore:coroutine method 'aclose' of 'AsyncIter' was never awaited:RuntimeWarning" +) async def test_asyncclient_get_all_unknown_result(): from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE From 631bda89e2c1de20467e316faddda9f0f9114c0b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 9 Jan 2026 14:58:27 -0800 Subject: [PATCH 17/27] updated filtered warnings --- pytest.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/pytest.ini b/pytest.ini index 308d1b494..7cd904ecc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -24,3 +24,4 @@ filterwarnings = ignore:.*\'asyncio.iscoroutinefunction\' is deprecated.*:DeprecationWarning ignore:.*\'asyncio.get_event_loop_policy\' is deprecated.*:DeprecationWarning ignore:.*Please upgrade to the latest Python version.*:FutureWarning + ignore:(?s).*using a Python version.*past its end of life.*:FutureWarning From 4ee909a39582e7a25acb15ca5f90becf0eb14711 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 9 Jan 2026 16:16:15 -0800 Subject: [PATCH 18/27] removed duplicate _BaseExecutePipeline --- .../firestore/transports/rest_base.py | 56 ------------------- 1 file changed, 56 deletions(-) diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 7d0c52f94..4b5ce5319 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -474,62 +474,6 @@ def _get_query_params_json(transcoded_request): ) ) - return query_params - - class _BaseExecutePipeline: - def __hash__(self): # pragma: NO COVER - return NotImplementedError("__hash__ must be implemented.") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return { - k: v - for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() - if k not in message_dict - } - - @staticmethod - def _get_http_options(): - http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline", - "body": "*", - }, - ] - return http_options - - @staticmethod - def _get_transcoded_request(http_options, request): - pb_request = firestore.ExecutePipelineRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - return transcoded_request - - @staticmethod - def _get_request_body_json(transcoded_request): - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=True - ) - return body - - @staticmethod - def _get_query_params_json(transcoded_request): - query_params = json.loads( - json_format.MessageToJson( - transcoded_request["query_params"], - use_integers_for_enums=True, - ) - ) - query_params.update( - 
_BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields(
-                    query_params
-                )
-            )
-            query_params["$alt"] = "json;enum-encoding=int"

             return query_params


From bbe8f4580807f8ff2593953a2a2d39b44a642e10 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 13 Jan 2026 13:47:49 -0800
Subject: [PATCH 19/27] chore(tests): re-enable pipeline system tests on
 kokoro (#1153)

Kokoro tests for pipelines were previously disabled until the backend
supports the feature. This branch will re-enable those tests when the
backend is ready.

Also removing index_mode, since this feature was pushed back to a future
release.
---
 google/cloud/firestore_v1/async_pipeline.py  | 10 ++-----
 google/cloud/firestore_v1/pipeline.py        | 10 ++-----
 google/cloud/firestore_v1/pipeline_result.py |  4 ---
 tests/system/test__helpers.py                |  4 +--
 tests/system/test_pipeline_acceptance.py     |  8 +-----
 tests/system/test_system.py                  | 29 --------------------
 tests/system/test_system_async.py            | 13 ---------
 tests/unit/v1/test_pipeline_result.py        | 12 --------
 8 files changed, 6 insertions(+), 84 deletions(-)

diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py
index 6b017d88e..5e3748b0f 100644
--- a/google/cloud/firestore_v1/async_pipeline.py
+++ b/google/cloud/firestore_v1/async_pipeline.py
@@ -68,7 +68,6 @@ async def execute(
         transaction: "AsyncTransaction" | None = None,
         read_time: datetime.datetime | None = None,
         explain_options: PipelineExplainOptions | None = None,
-        index_mode: str | None = None,
         additional_options: dict[str, Value | Constant] = {},
     ) -> PipelineSnapshot[PipelineResult]:
         """
@@ -87,10 +86,8 @@ async def execute(
             explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
                 Options to enable query profiling for this query. When set,
                 explain_metrics will be available on the returned list.
-            index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present.
-                Firestore will reject the request if there is not appropiate indexes to serve the query.
             additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
-                These options will take precedence over method argument if there is a conflict (e.g. explain_options, index_mode)
+                These options will take precedence over method argument if there is a conflict (e.g. explain_options)
         """
         kwargs = {k: v for k, v in locals().items() if k != "self"}
         stream = AsyncPipelineStream(PipelineResult, self, **kwargs)
@@ -103,7 +100,6 @@ def stream(
         read_time: datetime.datetime | None = None,
         transaction: "AsyncTransaction" | None = None,
         explain_options: PipelineExplainOptions | None = None,
-        index_mode: str | None = None,
         additional_options: dict[str, Value | Constant] = {},
     ) -> AsyncPipelineStream[PipelineResult]:
         """
@@ -122,10 +118,8 @@ def stream(
             explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]):
                 Options to enable query profiling for this query. When set,
                 explain_metrics will be available on the returned generator.
-            index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present.
-                Firestore will reject the request if there is not appropiate indexes to serve the query.
             additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
-                These options will take precedence over method argument if there is a conflict (e.g.
explain_options, index_mode) + These options will take precedence over method argument if there is a conflict (e.g. explain_options) """ kwargs = {k: v for k, v in locals().items() if k != "self"} return AsyncPipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 950eb6ffa..0c922ba78 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -65,7 +65,6 @@ def execute( transaction: "Transaction" | None = None, read_time: datetime.datetime | None = None, explain_options: PipelineExplainOptions | None = None, - index_mode: str | None = None, additional_options: dict[str, Value | Constant] = {}, ) -> PipelineSnapshot[PipelineResult]: """ @@ -84,10 +83,8 @@ def execute( explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned list. - index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. - Firestore will reject the request if there is not appropiate indexes to serve the query. additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. - These options will take precedence over method argument if there is a conflict (e.g. explain_options, index_mode) + These options will take precedence over method argument if there is a conflict (e.g. explain_options) """ kwargs = {k: v for k, v in locals().items() if k != "self"} stream = PipelineStream(PipelineResult, self, **kwargs) @@ -100,7 +97,6 @@ def stream( transaction: "Transaction" | None = None, read_time: datetime.datetime | None = None, explain_options: PipelineExplainOptions | None = None, - index_mode: str | None = None, additional_options: dict[str, Value | Constant] = {}, ) -> PipelineStream[PipelineResult]: """ @@ -119,10 +115,8 @@ def stream( explain_options (Optional[:class:`~google.cloud.firestore_v1.query_profile.PipelineExplainOptions`]): Options to enable query profiling for this query. When set, explain_metrics will be available on the returned generator. - index_mode (Optional[str]): Configures the pipeline to require a certain type of indexes to be present. - Firestore will reject the request if there is not appropiate indexes to serve the query. additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query. - These options will take precedence over method argument if there is a conflict (e.g. explain_options, index_mode) + These options will take precedence over method argument if there is a conflict (e.g. 
explain_options) """ kwargs = {k: v for k, v in locals().items() if k != "self"} return PipelineStream(PipelineResult, self, **kwargs) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 6be08fa57..923010815 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -178,7 +178,6 @@ def __init__( transaction: Transaction | AsyncTransaction | None, read_time: datetime.datetime | None, explain_options: PipelineExplainOptions | None, - index_mode: str | None, additional_options: dict[str, Constant | Value], ): # public @@ -192,7 +191,6 @@ def __init__( self._explain_stats: ExplainStats | None = None self._explain_options: PipelineExplainOptions | None = explain_options self._return_type = return_type - self._index_mode = index_mode self._additonal_options = { k: v if isinstance(v, Value) else v._to_pb() for k, v in additional_options.items() @@ -226,8 +224,6 @@ def _build_request(self) -> ExecutePipelineRequest: options = {} if self._explain_options: options["explain_options"] = self._explain_options._to_value() - if self._index_mode: - options["index_mode"] = Value(string_value=self._index_mode) if self._additonal_options: options.update(self._additonal_options) request = ExecutePipelineRequest( diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index 74b12b7c3..8032ae119 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -15,10 +15,8 @@ EMULATOR_CREDS = EmulatorCreds() FIRESTORE_EMULATOR = os.environ.get(_FIRESTORE_EMULATOR_HOST) is not None FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") -FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db") +FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db-native") # run all tests against default database, and a named database TEST_DATABASES = [None, FIRESTORE_OTHER_DB] TEST_DATABASES_W_ENTERPRISE = TEST_DATABASES + [FIRESTORE_ENTERPRISE_DB] -# TODO remove when kokoro fully supports enterprise mode/pipelines -IS_KOKORO_TEST = os.getenv("KOKORO_JOB_NAME") is not None diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 02a27ca86..4634037ab 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -33,16 +33,10 @@ from google.cloud.firestore import Client, AsyncClient -from test__helpers import FIRESTORE_ENTERPRISE_DB, IS_KOKORO_TEST +from test__helpers import FIRESTORE_ENTERPRISE_DB FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") -# TODO: enable kokoro tests when internal test project is whitelisted -pytestmark = pytest.mark.skipif( - condition=IS_KOKORO_TEST, - reason="Pipeline tests are currently not supported by kokoro", -) - test_dir_name = os.path.dirname(__file__) id_format = ( diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 615ff1226..328b29098 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -46,7 +46,6 @@ ENTERPRISE_MODE_ERROR, TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, - IS_KOKORO_TEST, FIRESTORE_ENTERPRISE_DB, ) @@ -67,12 +66,6 @@ def _get_credentials_and_project(): @pytest.fixture(scope="session") def database(request): - from test__helpers import FIRESTORE_ENTERPRISE_DB - - # enterprise mode currently does not support RunQuery calls in prod on kokoro test project - # TODO: remove skip when kokoro test project supports full 
enterprise mode - if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: - pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param @@ -101,11 +94,6 @@ def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - # return early on kokoro. Test project doesn't currently support pipelines - # TODO: enable pipeline verification when kokoro test project is whitelisted - if IS_KOKORO_TEST: - pytest.skip("skipping pipeline verification on kokoro") - def _clean_results(results): if isinstance(results, dict): return {k: _clean_results(v) for k, v in results.items()} @@ -1771,22 +1759,6 @@ def test_pipeline_explain_options_using_additional_options( assert "Execution:" in text_stats -@pytest.mark.skipif( - FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." -) -@pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) -def test_pipeline_index_mode(database, query_docs): - """test pipeline query with explicit index mode""" - - collection, _, allowed_vals = query_docs - client = collection._client - query = collection.where(filter=FieldFilter("a", "==", 1)) - pipeline = client.pipeline().create_from(query) - with pytest.raises(InvalidArgument) as e: - pipeline.execute(index_mode="fake_index") - assert "Invalid index_mode: fake_index" in str(e) - - @pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs @@ -1825,7 +1797,6 @@ def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") @pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) def test_pipeline_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 373c40118..1aaa79591 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -57,7 +57,6 @@ ENTERPRISE_MODE_ERROR, TEST_DATABASES, TEST_DATABASES_W_ENTERPRISE, - IS_KOKORO_TEST, FIRESTORE_ENTERPRISE_DB, ) @@ -145,12 +144,6 @@ def _verify_explain_metrics_analyze_false(explain_metrics): @pytest.fixture(scope="session") def database(request): - from test__helpers import FIRESTORE_ENTERPRISE_DB - - # enterprise mode currently does not support RunQuery calls in prod on kokoro test project - # TODO: remove skip when kokoro test project supports full enterprise mode - if request.param == FIRESTORE_ENTERPRISE_DB and IS_KOKORO_TEST: - pytest.skip("enterprise mode does not support RunQuery on kokoro") return request.param @@ -181,11 +174,6 @@ async def verify_pipeline(query): """ from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery - # return early on kokoro. 
Test project doesn't currently support pipelines - # TODO: enable pipeline verification when kokoro test project is whitelisted - if IS_KOKORO_TEST: - pytest.skip("skipping pipeline verification on kokoro") - def _clean_results(results): if isinstance(results, dict): return {k: _clean_results(v) for k, v in results.items()} @@ -1694,7 +1682,6 @@ async def test_pipeline_explain_options_using_additional_options( assert "Execution:" in text_stats -@pytest.mark.skipif(IS_KOKORO_TEST, reason="skipping pipeline verification on kokoro") @pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True) async def test_pipeline_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py index 579992741..eca622801 100644 --- a/tests/unit/v1/test_pipeline_result.py +++ b/tests/unit/v1/test_pipeline_result.py @@ -213,7 +213,6 @@ def test_ctor(self): expected_transaction = object() expected_read_time = 123 expected_explain_options = object() - expected_index_mode = "mode" expected_addtl_options = {} source = PipelineStream( expected_type, @@ -221,7 +220,6 @@ def test_ctor(self): expected_transaction, expected_read_time, expected_explain_options, - expected_index_mode, expected_addtl_options, ) instance = self._make_one(in_arr, source) @@ -229,7 +227,6 @@ def test_ctor(self): assert instance.pipeline == expected_pipeline assert instance._client == expected_pipeline._client assert instance._additonal_options == expected_addtl_options - assert instance._index_mode == expected_index_mode assert instance._explain_options == expected_explain_options assert instance._explain_stats is None assert instance._started is True @@ -281,7 +278,6 @@ def _mock_init_args(self): "transaction": None, "read_time": None, "explain_options": None, - "index_mode": None, "additional_options": {}, } @@ -312,7 +308,6 @@ def test_explain_stats(self): @pytest.mark.parametrize( "init_kwargs,expected_options", [ - ({"index_mode": "mode"}, {"index_mode": encode_value("mode")}), ( {"explain_options": PipelineExplainOptions()}, {"explain_options": encode_value({"mode": "analyze"})}, @@ -336,13 +331,6 @@ def test_explain_stats(self): }, {"explain_options": encode_value("override")}, ), - ( - { - "index_mode": "mode", - "additional_options": {"index_mode": Constant("new")}, - }, - {"index_mode": encode_value("new")}, - ), ], ) def test_build_request_options(self, init_kwargs, expected_options): From 1e0ec961f36318299c06663cf58ec3f44a37439b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 14 Jan 2026 08:37:16 -0800 Subject: [PATCH 20/27] chore: rename Function to FunctionExpression (#1155) Rename Function to FunctionExpression, and add a warning that other changes in the API surface may come before GA --- google/cloud/firestore_v1/async_pipeline.py | 9 ++ google/cloud/firestore_v1/pipeline.py | 9 ++ .../firestore_v1/pipeline_expressions.py | 139 ++++++++++-------- google/cloud/firestore_v1/pipeline_result.py | 5 + google/cloud/firestore_v1/pipeline_source.py | 5 + google/cloud/firestore_v1/pipeline_stages.py | 5 + tests/system/pipeline_e2e/aggregates.yaml | 34 ++--- tests/system/pipeline_e2e/array.yaml | 36 ++--- tests/system/pipeline_e2e/date_and_time.yaml | 26 ++-- tests/system/pipeline_e2e/general.yaml | 24 +-- tests/system/pipeline_e2e/logical.yaml | 74 +++++----- tests/system/pipeline_e2e/map.yaml | 20 +-- tests/system/pipeline_e2e/math.yaml | 48 +++--- tests/system/pipeline_e2e/string.yaml | 72 
++++----- tests/system/pipeline_e2e/vector.yaml | 14 +- tests/unit/v1/test_pipeline_expressions.py | 10 +- 16 files changed, 287 insertions(+), 243 deletions(-) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py index 5e3748b0f..d476cc283 100644 --- a/google/cloud/firestore_v1/async_pipeline.py +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases +""" from __future__ import annotations from typing import TYPE_CHECKING @@ -50,6 +55,10 @@ class AsyncPipeline(_BasePipeline): ... print(result) Use `client.pipeline()` to create instances of this class. + + .. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases """ def __init__(self, client: AsyncClient, *stages: stages.Stage): diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py index 0c922ba78..bce43fc86 100644 --- a/google/cloud/firestore_v1/pipeline.py +++ b/google/cloud/firestore_v1/pipeline.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" from __future__ import annotations from typing import TYPE_CHECKING @@ -47,6 +52,10 @@ class Pipeline(_BasePipeline): ... print(result) Use `client.pipeline()` to create instances of this class. + + .. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. """ def __init__(self, client: Client, *stages: stages.Stage): diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 7e86ef6eb..c0ff3923a 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" from __future__ import annotations from typing import ( @@ -95,7 +100,7 @@ class Expression(ABC): - **Field references:** Access values from document fields. - **Literals:** Represent constant values (strings, numbers, booleans). - - **Function calls:** Apply functions to one or more expressions. + - **FunctionExpression calls:** Apply functions to one or more expressions. - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set of documents. The `Expression` class provides a fluent API for building expressions. 
You can chain @@ -134,7 +139,7 @@ class expose_as_static: Example: >>> Field.of("test").add(5) - >>> Function.add("test", 5) + >>> FunctionExpression.add("test", 5) """ def __init__(self, instance_func): @@ -174,7 +179,9 @@ def add(self, other: Expression | float) -> "Expression": Returns: A new `Expression` representing the addition operation. """ - return Function("add", [self, self._cast_to_expr_or_convert_to_constant(other)]) + return FunctionExpression( + "add", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) @expose_as_static def subtract(self, other: Expression | float) -> "Expression": @@ -192,7 +199,7 @@ def subtract(self, other: Expression | float) -> "Expression": Returns: A new `Expression` representing the subtraction operation. """ - return Function( + return FunctionExpression( "subtract", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @@ -212,7 +219,7 @@ def multiply(self, other: Expression | float) -> "Expression": Returns: A new `Expression` representing the multiplication operation. """ - return Function( + return FunctionExpression( "multiply", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @@ -232,7 +239,7 @@ def divide(self, other: Expression | float) -> "Expression": Returns: A new `Expression` representing the division operation. """ - return Function( + return FunctionExpression( "divide", [self, self._cast_to_expr_or_convert_to_constant(other)] ) @@ -252,7 +259,9 @@ def mod(self, other: Expression | float) -> "Expression": Returns: A new `Expression` representing the modulo operation. """ - return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) + return FunctionExpression( + "mod", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) @expose_as_static def abs(self) -> "Expression": @@ -265,7 +274,7 @@ def abs(self) -> "Expression": Returns: A new `Expression` representing the absolute value. """ - return Function("abs", [self]) + return FunctionExpression("abs", [self]) @expose_as_static def ceil(self) -> "Expression": @@ -278,7 +287,7 @@ def ceil(self) -> "Expression": Returns: A new `Expression` representing the ceiling value. """ - return Function("ceil", [self]) + return FunctionExpression("ceil", [self]) @expose_as_static def exp(self) -> "Expression": @@ -291,7 +300,7 @@ def exp(self) -> "Expression": Returns: A new `Expression` representing the exponential value. """ - return Function("exp", [self]) + return FunctionExpression("exp", [self]) @expose_as_static def floor(self) -> "Expression": @@ -304,7 +313,7 @@ def floor(self) -> "Expression": Returns: A new `Expression` representing the floor value. """ - return Function("floor", [self]) + return FunctionExpression("floor", [self]) @expose_as_static def ln(self) -> "Expression": @@ -317,7 +326,7 @@ def ln(self) -> "Expression": Returns: A new `Expression` representing the natural logarithm. """ - return Function("ln", [self]) + return FunctionExpression("ln", [self]) @expose_as_static def log(self, base: Expression | float) -> "Expression": @@ -335,7 +344,9 @@ def log(self, base: Expression | float) -> "Expression": Returns: A new `Expression` representing the logarithm. """ - return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)]) + return FunctionExpression( + "log", [self, self._cast_to_expr_or_convert_to_constant(base)] + ) @expose_as_static def log10(self) -> "Expression": @@ -347,7 +358,7 @@ def log10(self) -> "Expression": Returns: A new `Expression` representing the logarithm. 
""" - return Function("log10", [self]) + return FunctionExpression("log10", [self]) @expose_as_static def pow(self, exponent: Expression | float) -> "Expression": @@ -365,7 +376,7 @@ def pow(self, exponent: Expression | float) -> "Expression": Returns: A new `Expression` representing the power operation. """ - return Function( + return FunctionExpression( "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] ) @@ -380,7 +391,7 @@ def round(self) -> "Expression": Returns: A new `Expression` representing the rounded value. """ - return Function("round", [self]) + return FunctionExpression("round", [self]) @expose_as_static def sqrt(self) -> "Expression": @@ -393,7 +404,7 @@ def sqrt(self) -> "Expression": Returns: A new `Expression` representing the square root. """ - return Function("sqrt", [self]) + return FunctionExpression("sqrt", [self]) @expose_as_static def logical_maximum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": @@ -415,7 +426,7 @@ def logical_maximum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the logical maximum operation. """ - return Function( + return FunctionExpression( "maximum", [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], infix_name_override="logical_maximum", @@ -441,7 +452,7 @@ def logical_minimum(self, *others: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the logical minimum operation. """ - return Function( + return FunctionExpression( "minimum", [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], infix_name_override="logical_minimum", @@ -630,7 +641,7 @@ def not_equal_any( ) @expose_as_static - def array_get(self, offset: Expression | int) -> "Function": + def array_get(self, offset: Expression | int) -> "FunctionExpression": """ Creates an expression that indexes into an array from the beginning or end and returns the element. A negative offset starts from the end. @@ -644,7 +655,7 @@ def array_get(self, offset: Expression | int) -> "Function": Returns: A new `Expression` representing the `array_get` operation. """ - return Function( + return FunctionExpression( "array_get", [self, self._cast_to_expr_or_convert_to_constant(offset)] ) @@ -736,7 +747,7 @@ def array_length(self) -> "Expression": Returns: A new `Expression` representing the length of the array. """ - return Function("array_length", [self]) + return FunctionExpression("array_length", [self]) @expose_as_static def array_reverse(self) -> "Expression": @@ -749,7 +760,7 @@ def array_reverse(self) -> "Expression": Returns: A new `Expression` representing the reversed array. """ - return Function("array_reverse", [self]) + return FunctionExpression("array_reverse", [self]) @expose_as_static def array_concat( @@ -767,7 +778,7 @@ def array_concat( Returns: A new `Expression` representing the concatenated array. """ - return Function( + return FunctionExpression( "array_concat", [self] + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays], @@ -783,7 +794,7 @@ def concat(self, *others: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the concatenated value. """ - return Function( + return FunctionExpression( "concat", [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], ) @@ -800,7 +811,7 @@ def length(self) -> "Expression": Returns: A new `Expression` representing the length of the expression. 
""" - return Function("length", [self]) + return FunctionExpression("length", [self]) @expose_as_static def is_absent(self) -> "BooleanExpression": @@ -830,7 +841,7 @@ def if_absent(self, default_value: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the ifAbsent operation. """ - return Function( + return FunctionExpression( "if_absent", [self, self._cast_to_expr_or_convert_to_constant(default_value)], ) @@ -846,7 +857,7 @@ def is_error(self): Returns: A new `Expression` representing the isError operation. """ - return Function("is_error", [self]) + return FunctionExpression("is_error", [self]) @expose_as_static def if_error(self, then_value: Expression | CONSTANT_TYPE) -> "Expression": @@ -863,7 +874,7 @@ def if_error(self, then_value: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the ifError operation. """ - return Function( + return FunctionExpression( "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)] ) @@ -987,7 +998,7 @@ def char_length(self) -> "Expression": Returns: A new `Expression` representing the length of the string. """ - return Function("char_length", [self]) + return FunctionExpression("char_length", [self]) @expose_as_static def byte_length(self) -> "Expression": @@ -1000,7 +1011,7 @@ def byte_length(self) -> "Expression": Returns: A new `Expression` representing the byte length of the string. """ - return Function("byte_length", [self]) + return FunctionExpression("byte_length", [self]) @expose_as_static def like(self, pattern: Expression | str) -> "BooleanExpression": @@ -1138,7 +1149,7 @@ def string_concat(self, *elements: Expression | CONSTANT_TYPE) -> "Expression": Returns: A new `Expression` representing the concatenated string. """ - return Function( + return FunctionExpression( "string_concat", [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) @@ -1154,7 +1165,7 @@ def to_lower(self) -> "Expression": Returns: A new `Expression` representing the lowercase string. """ - return Function("to_lower", [self]) + return FunctionExpression("to_lower", [self]) @expose_as_static def to_upper(self) -> "Expression": @@ -1167,7 +1178,7 @@ def to_upper(self) -> "Expression": Returns: A new `Expression` representing the uppercase string. """ - return Function("to_upper", [self]) + return FunctionExpression("to_upper", [self]) @expose_as_static def trim(self) -> "Expression": @@ -1180,7 +1191,7 @@ def trim(self) -> "Expression": Returns: A new `Expression` representing the trimmed string. """ - return Function("trim", [self]) + return FunctionExpression("trim", [self]) @expose_as_static def string_reverse(self) -> "Expression": @@ -1193,7 +1204,7 @@ def string_reverse(self) -> "Expression": Returns: A new `Expression` representing the reversed string. """ - return Function("string_reverse", [self]) + return FunctionExpression("string_reverse", [self]) @expose_as_static def substring( @@ -1217,7 +1228,7 @@ def substring( args = [self, self._cast_to_expr_or_convert_to_constant(position)] if length is not None: args.append(self._cast_to_expr_or_convert_to_constant(length)) - return Function("substring", args) + return FunctionExpression("substring", args) @expose_as_static def join(self, delimeter: Expression | str) -> "Expression": @@ -1233,7 +1244,7 @@ def join(self, delimeter: Expression | str) -> "Expression": Returns: A new `Expression` representing the joined string. 
""" - return Function( + return FunctionExpression( "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] ) @@ -1251,7 +1262,7 @@ def map_get(self, key: str | Constant[str]) -> "Expression": Returns: A new `Expression` representing the value associated with the given key in the map. """ - return Function( + return FunctionExpression( "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)] ) @@ -1269,7 +1280,7 @@ def map_remove(self, key: str | Constant[str]) -> "Expression": Returns: A new `Expression` representing the map_remove operation. """ - return Function( + return FunctionExpression( "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)] ) @@ -1284,7 +1295,7 @@ def map_merge( Example: >>> Map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) - >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) + >>> Field.of("settings").map_merge({"enabled":True}, FunctionExpression.conditional(Field.of('isAdmin'), {"admin":True}, {}}) Args: *other_maps: Sequence of maps to merge into the resulting map. @@ -1292,7 +1303,7 @@ def map_merge( Returns: A new `Expression` representing the value associated with the given key in the map. """ - return Function( + return FunctionExpression( "map_merge", [self] + [self._cast_to_expr_or_convert_to_constant(m) for m in other_maps], ) @@ -1313,7 +1324,7 @@ def cosine_distance(self, other: Expression | list[float] | Vector) -> "Expressi Returns: A new `Expression` representing the cosine distance between the two vectors. """ - return Function( + return FunctionExpression( "cosine_distance", [ self, @@ -1339,7 +1350,7 @@ def euclidean_distance( Returns: A new `Expression` representing the Euclidean distance between the two vectors. """ - return Function( + return FunctionExpression( "euclidean_distance", [ self, @@ -1363,7 +1374,7 @@ def dot_product(self, other: Expression | list[float] | Vector) -> "Expression": Returns: A new `Expression` representing the dot product between the two vectors. """ - return Function( + return FunctionExpression( "dot_product", [ self, @@ -1382,7 +1393,7 @@ def vector_length(self) -> "Expression": Returns: A new `Expression` representing the length of the vector. """ - return Function("vector_length", [self]) + return FunctionExpression("vector_length", [self]) @expose_as_static def timestamp_to_unix_micros(self) -> "Expression": @@ -1398,7 +1409,7 @@ def timestamp_to_unix_micros(self) -> "Expression": Returns: A new `Expression` representing the number of microseconds since the epoch. """ - return Function("timestamp_to_unix_micros", [self]) + return FunctionExpression("timestamp_to_unix_micros", [self]) @expose_as_static def unix_micros_to_timestamp(self) -> "Expression": @@ -1412,7 +1423,7 @@ def unix_micros_to_timestamp(self) -> "Expression": Returns: A new `Expression` representing the timestamp. """ - return Function("unix_micros_to_timestamp", [self]) + return FunctionExpression("unix_micros_to_timestamp", [self]) @expose_as_static def timestamp_to_unix_millis(self) -> "Expression": @@ -1428,7 +1439,7 @@ def timestamp_to_unix_millis(self) -> "Expression": Returns: A new `Expression` representing the number of milliseconds since the epoch. 
""" - return Function("timestamp_to_unix_millis", [self]) + return FunctionExpression("timestamp_to_unix_millis", [self]) @expose_as_static def unix_millis_to_timestamp(self) -> "Expression": @@ -1442,7 +1453,7 @@ def unix_millis_to_timestamp(self) -> "Expression": Returns: A new `Expression` representing the timestamp. """ - return Function("unix_millis_to_timestamp", [self]) + return FunctionExpression("unix_millis_to_timestamp", [self]) @expose_as_static def timestamp_to_unix_seconds(self) -> "Expression": @@ -1458,7 +1469,7 @@ def timestamp_to_unix_seconds(self) -> "Expression": Returns: A new `Expression` representing the number of seconds since the epoch. """ - return Function("timestamp_to_unix_seconds", [self]) + return FunctionExpression("timestamp_to_unix_seconds", [self]) @expose_as_static def unix_seconds_to_timestamp(self) -> "Expression": @@ -1472,7 +1483,7 @@ def unix_seconds_to_timestamp(self) -> "Expression": Returns: A new `Expression` representing the timestamp. """ - return Function("unix_seconds_to_timestamp", [self]) + return FunctionExpression("unix_seconds_to_timestamp", [self]) @expose_as_static def timestamp_add( @@ -1494,7 +1505,7 @@ def timestamp_add( Returns: A new `Expression` representing the resulting timestamp. """ - return Function( + return FunctionExpression( "timestamp_add", [ self, @@ -1523,7 +1534,7 @@ def timestamp_subtract( Returns: A new `Expression` representing the resulting timestamp. """ - return Function( + return FunctionExpression( "timestamp_subtract", [ self, @@ -1543,7 +1554,7 @@ def collection_id(self): Returns: A new `Expression` representing the collection ID. """ - return Function("collection_id", [self]) + return FunctionExpression("collection_id", [self]) @expose_as_static def document_id(self): @@ -1556,7 +1567,7 @@ def document_id(self): Returns: A new `Expression` representing the document ID. """ - return Function("document_id", [self]) + return FunctionExpression("document_id", [self]) def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. @@ -1634,7 +1645,7 @@ def _to_pb(self) -> Value: return encode_value(self.value) -class Function(Expression): +class FunctionExpression(Expression): """A base class for expressions that represent function calls.""" def __init__( @@ -1652,7 +1663,7 @@ def __init__( def __repr__(self): """ - Most Functions can be triggered infix. Eg: Field.of('age').greater_than(18). + Most FunctionExpressions can be triggered infix. Eg: Field.of('age').greater_than(18). 
Display them this way in the repr string where possible """ @@ -1667,7 +1678,7 @@ def __repr__(self): return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" def __eq__(self, other): - if not isinstance(other, Function): + if not isinstance(other, FunctionExpression): return False else: return other.name == self.name and other.params == self.params @@ -1681,7 +1692,7 @@ def _to_pb(self): ) -class AggregateFunction(Function): +class AggregateFunction(FunctionExpression): """A base class for aggregation functions that operate across multiple inputs.""" @@ -1778,7 +1789,7 @@ def _to_pb(self): return Value(field_reference_value=self.path) -class BooleanExpression(Function): +class BooleanExpression(FunctionExpression): """Filters the given data in some way.""" @staticmethod @@ -1845,7 +1856,7 @@ def _from_query_filter_pb(filter_pb, client): raise TypeError(f"Unexpected filter type: {type(filter_pb)}") -class Array(Function): +class Array(FunctionExpression): """ Creates an expression that creates a Firestore array value from an input list. @@ -1868,7 +1879,7 @@ def __repr__(self): return f"Array({self.params})" -class Map(Function): +class Map(FunctionExpression): """ Creates an expression that creates a Firestore map value from an input dict. @@ -2004,7 +2015,7 @@ def __init__(self, expression: Expression | None = None): super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) -class CurrentTimestamp(Function): +class CurrentTimestamp(FunctionExpression): """Creates an expression that returns the current timestamp Returns: diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 923010815..704811b94 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" from __future__ import annotations from typing import ( diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index 3fb73b365..8f3c0a626 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. +""" from __future__ import annotations from typing import Generic, TypeVar, TYPE_CHECKING diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 37829465e..18aa27044 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +.. warning:: + **Preview API**: Firestore Pipelines is currently in preview and is + subject to potential breaking changes in future releases. 
+""" from __future__ import annotations from typing import Optional, Sequence, TYPE_CHECKING diff --git a/tests/system/pipeline_e2e/aggregates.yaml b/tests/system/pipeline_e2e/aggregates.yaml index 9593213ed..64a42698b 100644 --- a/tests/system/pipeline_e2e/aggregates.yaml +++ b/tests/system/pipeline_e2e/aggregates.yaml @@ -4,7 +4,7 @@ tests: - Collection: books - Aggregate: - AliasedExpression: - - Function.count: + - FunctionExpression.count: - Field: rating - "count" assert_results: @@ -30,8 +30,8 @@ tests: - Collection: books - Aggregate: - AliasedExpression: - - Function.count_if: - - Function.greater_than: + - FunctionExpression.count_if: + - FunctionExpression.greater_than: - Field: rating - Constant: 4.2 - "count_if_rating_gt_4_2" @@ -62,7 +62,7 @@ tests: - Collection: books - Aggregate: - AliasedExpression: - - Function.count_distinct: + - FunctionExpression.count_distinct: - Field: genre - "distinct_genres" assert_results: @@ -87,20 +87,20 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Science Fiction - Aggregate: - AliasedExpression: - - Function.count: + - FunctionExpression.count: - Field: rating - "count" - AliasedExpression: - - Function.average: + - FunctionExpression.average: - Field: rating - "avg_rating" - AliasedExpression: - - Function.maximum: + - FunctionExpression.maximum: - Field: rating - "max_rating" assert_results: @@ -144,7 +144,7 @@ tests: pipeline: - Collection: books - Where: - - Function.less_than: + - FunctionExpression.less_than: - Field: published - Constant: 1900 - Aggregate: @@ -155,18 +155,18 @@ tests: pipeline: - Collection: books - Where: - - Function.less_than: + - FunctionExpression.less_than: - Field: published - Constant: 1984 - Aggregate: accumulators: - AliasedExpression: - - Function.average: + - FunctionExpression.average: - Field: rating - "avg_rating" groups: [genre] - Where: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: avg_rating - Constant: 4.3 - Sort: @@ -226,15 +226,15 @@ tests: - Collection: books - Aggregate: - AliasedExpression: - - Function.count: + - FunctionExpression.count: - Field: rating - "count" - AliasedExpression: - - Function.maximum: + - FunctionExpression.maximum: - Field: rating - "max_rating" - AliasedExpression: - - Function.minimum: + - FunctionExpression.minimum: - Field: published - "min_published" assert_results: @@ -271,12 +271,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Science Fiction - Aggregate: - AliasedExpression: - - Function.sum: + - FunctionExpression.sum: - Field: rating - "total_rating" assert_results: diff --git a/tests/system/pipeline_e2e/array.yaml b/tests/system/pipeline_e2e/array.yaml index acdded36b..f82f1cbc1 100644 --- a/tests/system/pipeline_e2e/array.yaml +++ b/tests/system/pipeline_e2e/array.yaml @@ -3,7 +3,7 @@ tests: pipeline: - Collection: books - Where: - - Function.array_contains: + - FunctionExpression.array_contains: - Field: tags - Constant: comedy assert_results: @@ -33,7 +33,7 @@ tests: pipeline: - Collection: books - Where: - - Function.array_contains_any: + - FunctionExpression.array_contains_any: - Field: tags - - Constant: comedy - Constant: classic @@ -81,7 +81,7 @@ tests: pipeline: - Collection: books - Where: - - Function.array_contains_all: + - FunctionExpression.array_contains_all: - Field: tags - - Constant: adventure - Constant: magic @@ -117,11 +117,11 @@ tests: - Collection: 
books - Select: - AliasedExpression: - - Function.array_length: + - FunctionExpression.array_length: - Field: tags - "tagsCount" - Where: - - Function.equal: + - FunctionExpression.equal: - Field: tagsCount - Constant: 3 assert_results: # All documents have 3 tags @@ -161,12 +161,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.array_reverse: + - FunctionExpression.array_reverse: - Field: tags - "reversedTags" assert_results: @@ -178,12 +178,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.array_concat: + - FunctionExpression.array_concat: - Field: tags - ["new_tag", "another_tag"] - "concatenatedTags" @@ -225,12 +225,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "Dune" - Select: - AliasedExpression: - - Function.array_concat: + - FunctionExpression.array_concat: - Field: tags - ["sci-fi"] - ["classic", "epic"] @@ -279,12 +279,12 @@ tests: - Collection: books - AddFields: - AliasedExpression: - - Function.array_concat: + - FunctionExpression.array_concat: - Field: tags - Array: ["Dystopian"] - "new_tags" - Where: - - Function.array_contains_any: + - FunctionExpression.array_contains_any: - Field: new_tags - - Constant: non_existent_tag - Field: genre @@ -352,7 +352,7 @@ tests: - Limit: 1 - Select: - AliasedExpression: - - Function.array_concat: + - FunctionExpression.array_concat: - Array: [1, 2, 3] - Array: [4, 5] - "concatenated" @@ -390,12 +390,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.array_get: + - FunctionExpression.array_get: - Field: tags - Constant: 0 - "firstTag" @@ -428,12 +428,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.array_get: + - FunctionExpression.array_get: - Field: tags - Constant: -1 - "lastTag" diff --git a/tests/system/pipeline_e2e/date_and_time.yaml b/tests/system/pipeline_e2e/date_and_time.yaml index cb5323dc1..2319b333b 100644 --- a/tests/system/pipeline_e2e/date_and_time.yaml +++ b/tests/system/pipeline_e2e/date_and_time.yaml @@ -6,13 +6,13 @@ tests: - Select: - AliasedExpression: - And: - - Function.greater_than_or_equal: + - FunctionExpression.greater_than_or_equal: - CurrentTimestamp: [] - - Function.unix_seconds_to_timestamp: + - FunctionExpression.unix_seconds_to_timestamp: - Constant: 1735689600 # 2025-01-01 - - Function.less_than: + - FunctionExpression.less_than: - CurrentTimestamp: [] - - Function.unix_seconds_to_timestamp: + - FunctionExpression.unix_seconds_to_timestamp: - Constant: 4892438400 # 2125-01-01 - "is_between_2025_and_2125" assert_results: @@ -52,42 +52,42 @@ tests: args: - integerValue: '4892438400' name: select - - description: testTimestampFunctions + - description: testTimestampFunctionExpressions pipeline: - Collection: timestamps - Select: - AliasedExpression: - - Function.timestamp_to_unix_micros: + - FunctionExpression.timestamp_to_unix_micros: - Field: time - "micros" - AliasedExpression: - - 
Function.timestamp_to_unix_millis: + - FunctionExpression.timestamp_to_unix_millis: - Field: time - "millis" - AliasedExpression: - - Function.timestamp_to_unix_seconds: + - FunctionExpression.timestamp_to_unix_seconds: - Field: time - "seconds" - AliasedExpression: - - Function.unix_micros_to_timestamp: + - FunctionExpression.unix_micros_to_timestamp: - Field: micros - "from_micros" - AliasedExpression: - - Function.unix_millis_to_timestamp: + - FunctionExpression.unix_millis_to_timestamp: - Field: millis - "from_millis" - AliasedExpression: - - Function.unix_seconds_to_timestamp: + - FunctionExpression.unix_seconds_to_timestamp: - Field: seconds - "from_seconds" - AliasedExpression: - - Function.timestamp_add: + - FunctionExpression.timestamp_add: - Field: time - Constant: "day" - Constant: 1 - "plus_day" - AliasedExpression: - - Function.timestamp_subtract: + - FunctionExpression.timestamp_subtract: - Field: time - Constant: "hour" - Constant: 1 diff --git a/tests/system/pipeline_e2e/general.yaml b/tests/system/pipeline_e2e/general.yaml index 8ff3f60d2..46a10cd4d 100644 --- a/tests/system/pipeline_e2e/general.yaml +++ b/tests/system/pipeline_e2e/general.yaml @@ -57,13 +57,13 @@ tests: - Collection: books - AddFields: - AliasedExpression: - - Function.string_concat: + - FunctionExpression.string_concat: - Field: author - Constant: _ - Field: title - "author_title" - AliasedExpression: - - Function.string_concat: + - FunctionExpression.string_concat: - Field: title - Constant: _ - Field: author @@ -227,14 +227,14 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Romance - Union: - Pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Dystopian - Select: @@ -299,12 +299,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.document_id: + - FunctionExpression.document_id: - Field: __name__ - "doc_id" assert_results: @@ -337,7 +337,7 @@ tests: - Limit: 1 - Select: - AliasedExpression: - - Function.collection_id: + - FunctionExpression.collection_id: - Field: __name__ - "collectionName" assert_results: @@ -481,7 +481,7 @@ tests: - Select: - AliasedExpression: - Conditional: - - Function.greater_than_or_equal: + - FunctionExpression.greater_than_or_equal: - Field: count - Constant: 10 - Constant: True @@ -536,7 +536,7 @@ tests: reference_value: "/books" - RawStage: - "where" - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - RawStage: @@ -571,7 +571,7 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: @@ -609,7 +609,7 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: @@ -660,7 +660,7 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - ReplaceWith: diff --git a/tests/system/pipeline_e2e/logical.yaml b/tests/system/pipeline_e2e/logical.yaml index bbb71921b..296cfda14 100644 --- a/tests/system/pipeline_e2e/logical.yaml +++ b/tests/system/pipeline_e2e/logical.yaml @@ -4,10 +4,10 @@ tests: - 
Collection: books - Where: - And: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: rating - Constant: 4.5 - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Science Fiction assert_results: @@ -49,10 +49,10 @@ tests: - Collection: books - Where: - Or: - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Romance - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Dystopian - Select: @@ -105,13 +105,13 @@ tests: - Collection: books - Where: - And: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: rating - Constant: 4.2 - - Function.less_than_or_equal: + - FunctionExpression.less_than_or_equal: - Field: rating - Constant: 4.5 - - Function.not_equal: + - FunctionExpression.not_equal: - Field: genre - Constant: Science Fiction - Select: @@ -176,13 +176,13 @@ tests: - Where: - Or: - And: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: rating - Constant: 4.5 - - Function.equal: + - FunctionExpression.equal: - Field: genre - Constant: Science Fiction - - Function.less_than: + - FunctionExpression.less_than: - Field: published - Constant: 1900 - Select: @@ -242,7 +242,7 @@ tests: pipeline: - Collection: errors - Where: - - Function.equal: + - FunctionExpression.equal: - Field: value - null assert_results: @@ -264,7 +264,7 @@ tests: pipeline: - Collection: errors - Where: - - Function.equal: + - FunctionExpression.equal: - Field: value - NaN assert_count: 1 @@ -285,7 +285,7 @@ tests: pipeline: - Collection: books - Where: - - Function.is_absent: + - FunctionExpression.is_absent: - Field: awards.pulitzer assert_count: 9 assert_proto: @@ -305,13 +305,13 @@ tests: - Collection: books - Select: - AliasedExpression: - - Function.if_absent: + - FunctionExpression.if_absent: - Field: awards.pulitzer - Constant: false - "pulitzer_award" - title - Where: - - Function.equal: + - FunctionExpression.equal: - Field: pulitzer_award - Constant: true assert_results: @@ -347,8 +347,8 @@ tests: - Collection: books - Select: - AliasedExpression: - - Function.is_error: - - Function.divide: + - FunctionExpression.is_error: + - FunctionExpression.divide: - Field: rating - Constant: "string" - "is_error_result" @@ -382,8 +382,8 @@ tests: - Collection: books - Select: - AliasedExpression: - - Function.if_error: - - Function.divide: + - FunctionExpression.if_error: + - FunctionExpression.divide: - Field: rating - Field: genre - Constant: "An error occurred" @@ -418,17 +418,17 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: Douglas Adams - Select: - AliasedExpression: - - Function.logical_maximum: + - FunctionExpression.logical_maximum: - Field: rating - Constant: 4.5 - "max_rating" - AliasedExpression: - - Function.logical_minimum: + - FunctionExpression.logical_minimum: - Field: published - Constant: 1900 - "min_published" @@ -468,19 +468,19 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: Douglas Adams - Select: - AliasedExpression: - - Function.logical_maximum: + - FunctionExpression.logical_maximum: - Field: rating - Constant: 4.5 - Constant: 3.0 - Constant: 5.0 - "max_rating" - AliasedExpression: - - Function.logical_minimum: + - FunctionExpression.logical_minimum: - Field: published - Constant: 1900 - Constant: 2000 @@ -526,7 +526,7 @@ tests: pipeline: - Collection: books - Where: - - Function.greater_than_or_equal: + - 
FunctionExpression.greater_than_or_equal: - Field: rating - Constant: 4.6 - Select: @@ -546,11 +546,11 @@ tests: - Collection: books - Where: - And: - - Function.equal_any: + - FunctionExpression.equal_any: - Field: genre - - Constant: Romance - Constant: Dystopian - - Function.not_equal_any: + - FunctionExpression.not_equal_any: - Field: author - - Constant: "George Orwell" assert_results: @@ -582,9 +582,9 @@ tests: - Collection: books - Where: - And: - - Function.exists: + - FunctionExpression.exists: - Field: awards.pulitzer - - Function.equal: + - FunctionExpression.equal: - Field: awards.pulitzer - Constant: true - Select: @@ -596,10 +596,10 @@ tests: - Collection: books - Where: - Xor: - - - Function.equal: + - - FunctionExpression.equal: - Field: genre - Constant: Romance - - Function.greater_than: + - FunctionExpression.greater_than: - Field: published - Constant: 1980 - Select: @@ -624,7 +624,7 @@ tests: - title - AliasedExpression: - Conditional: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: published - Constant: 1950 - Constant: "Modern" @@ -648,7 +648,7 @@ tests: pipeline: - Collection: books - Where: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: published - Field: rating - Select: @@ -658,14 +658,14 @@ tests: pipeline: - Collection: books - Where: - - Function.exists: + - FunctionExpression.exists: - Field: non_existent_field assert_count: 0 - description: testConditionalWithFields pipeline: - Collection: books - Where: - - Function.equal_any: + - FunctionExpression.equal_any: - Field: title - - Constant: "Dune" - Constant: "1984" @@ -673,7 +673,7 @@ tests: - title - AliasedExpression: - Conditional: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: published - Constant: 1950 - Field: author diff --git a/tests/system/pipeline_e2e/map.yaml b/tests/system/pipeline_e2e/map.yaml index 546af1351..3e5e5de12 100644 --- a/tests/system/pipeline_e2e/map.yaml +++ b/tests/system/pipeline_e2e/map.yaml @@ -8,13 +8,13 @@ tests: - DESCENDING - Select: - AliasedExpression: - - Function.map_get: + - FunctionExpression.map_get: - Field: awards - hugo - "hugoAward" - Field: title - Where: - - Function.equal: + - FunctionExpression.equal: - Field: hugoAward - Constant: true assert_results: @@ -59,7 +59,7 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "Dune" - AddFields: @@ -68,7 +68,7 @@ tests: - "award_name" - Select: - AliasedExpression: - - Function.map_get: + - FunctionExpression.map_get: - Field: awards - Field: award_name - "hugoAward" @@ -111,12 +111,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "Dune" - Select: - AliasedExpression: - - Function.map_remove: + - FunctionExpression.map_remove: - Field: awards - "nebula" - "awards_removed" @@ -150,12 +150,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "Dune" - Select: - AliasedExpression: - - Function.map_merge: + - FunctionExpression.map_merge: - Field: awards - Map: elements: {"new_award": true, "hugo": false} @@ -206,7 +206,7 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: awards.hugo - Constant: true - Sort: @@ -256,7 +256,7 @@ tests: - Limit: 1 - Select: - AliasedExpression: - - Function.map_merge: + - FunctionExpression.map_merge: - Map: elements: {"a": "orig", 
"b": "orig"} - Map: diff --git a/tests/system/pipeline_e2e/math.yaml b/tests/system/pipeline_e2e/math.yaml index b62c0510b..4d35f746d 100644 --- a/tests/system/pipeline_e2e/math.yaml +++ b/tests/system/pipeline_e2e/math.yaml @@ -3,61 +3,61 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "Dune" - Select: - AliasedExpression: - - Function.add: + - FunctionExpression.add: - Field: published - Field: rating - "pub_plus_rating" assert_results: - pub_plus_rating: 1969.6 - - description: testMathFunctionessions + - description: testMathFunctionExpressionessions pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - AliasedExpression: - - Function.abs: + - FunctionExpression.abs: - Field: rating - "abs_rating" - AliasedExpression: - - Function.ceil: + - FunctionExpression.ceil: - Field: rating - "ceil_rating" - AliasedExpression: - - Function.exp: + - FunctionExpression.exp: - Field: rating - "exp_rating" - AliasedExpression: - - Function.floor: + - FunctionExpression.floor: - Field: rating - "floor_rating" - AliasedExpression: - - Function.ln: + - FunctionExpression.ln: - Field: rating - "ln_rating" - AliasedExpression: - - Function.log10: + - FunctionExpression.log10: - Field: rating - "log_rating_base10" - AliasedExpression: - - Function.log: + - FunctionExpression.log: - Field: rating - Constant: 2 - "log_rating_base2" - AliasedExpression: - - Function.pow: + - FunctionExpression.pow: - Field: rating - Constant: 2 - "pow_rating" - AliasedExpression: - - Function.sqrt: + - FunctionExpression.sqrt: - Field: rating - "sqrt_rating" assert_results_approximate: @@ -134,11 +134,11 @@ tests: - fieldReferenceValue: rating name: sqrt name: select - - description: testRoundFunctionessions + - description: testRoundFunctionExpressionessions pipeline: - Collection: books - Where: - - Function.equal_any: + - FunctionExpression.equal_any: - Field: title - - Constant: "To Kill a Mockingbird" # rating 4.2 - Constant: "Pride and Prejudice" # rating 4.5 @@ -146,7 +146,7 @@ tests: - Select: - title - AliasedExpression: - - Function.round: + - FunctionExpression.round: - Field: rating - "round_rating" - Sort: @@ -201,42 +201,42 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - AliasedExpression: - - Function.add: + - FunctionExpression.add: - Field: rating - Constant: 1 - "ratingPlusOne" - AliasedExpression: - - Function.subtract: + - FunctionExpression.subtract: - Field: published - Constant: 1900 - "yearsSince1900" - AliasedExpression: - - Function.multiply: + - FunctionExpression.multiply: - Field: rating - Constant: 10 - "ratingTimesTen" - AliasedExpression: - - Function.divide: + - FunctionExpression.divide: - Field: rating - Constant: 2 - "ratingDividedByTwo" - AliasedExpression: - - Function.multiply: + - FunctionExpression.multiply: - Field: rating - Constant: 20 - "ratingTimes20" - AliasedExpression: - - Function.add: + - FunctionExpression.add: - Field: rating - Constant: 3 - "ratingPlus3" - AliasedExpression: - - Function.mod: + - FunctionExpression.mod: - Field: rating - Constant: 2 - "ratingMod2" diff --git a/tests/system/pipeline_e2e/string.yaml b/tests/system/pipeline_e2e/string.yaml index d612483e1..20a97ba60 100644 --- a/tests/system/pipeline_e2e/string.yaml +++ b/tests/system/pipeline_e2e/string.yaml @@ -8,7 +8,7 @@ 
tests: - ASCENDING - Select: - AliasedExpression: - - Function.string_concat: + - FunctionExpression.string_concat: - Field: author - Constant: " - " - Field: title @@ -48,7 +48,7 @@ tests: pipeline: - Collection: books - Where: - - Function.starts_with: + - FunctionExpression.starts_with: - Field: title - Constant: The - Select: @@ -93,7 +93,7 @@ tests: pipeline: - Collection: books - Where: - - Function.ends_with: + - FunctionExpression.ends_with: - Field: title - Constant: y - Select: @@ -136,18 +136,18 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.concat: + - FunctionExpression.concat: - Field: author - Constant: ": " - Field: title - "author_title" - AliasedExpression: - - Function.concat: + - FunctionExpression.concat: - Field: tags - - Constant: "new_tag" - "concatenatedTags" @@ -162,20 +162,20 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: title - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - AliasedExpression: - - Function.length: + - FunctionExpression.length: - Field: title - "titleLength" - AliasedExpression: - - Function.length: + - FunctionExpression.length: - Field: tags - "tagsLength" - AliasedExpression: - - Function.length: + - FunctionExpression.length: - Field: awards - "awardsLength" assert_results: @@ -187,12 +187,12 @@ tests: - Collection: books - Select: - AliasedExpression: - - Function.char_length: + - FunctionExpression.char_length: - Field: title - "titleLength" - title - Where: - - Function.greater_than: + - FunctionExpression.greater_than: - Field: titleLength - Constant: 20 - Sort: @@ -244,12 +244,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.char_length: + - FunctionExpression.char_length: - Field: title - "title_length" assert_results: @@ -280,13 +280,13 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: Douglas Adams - Select: - AliasedExpression: - - Function.byte_length: - - Function.string_concat: + - FunctionExpression.byte_length: + - FunctionExpression.string_concat: - Field: title - Constant: _银河系漫游指南 - "title_byte_length" @@ -322,7 +322,7 @@ tests: pipeline: - Collection: books - Where: - - Function.like: + - FunctionExpression.like: - Field: title - Constant: "%Guide%" - Select: @@ -334,7 +334,7 @@ tests: pipeline: - Collection: books - Where: - - Function.regex_contains: + - FunctionExpression.regex_contains: - Field: title - Constant: "(?i)(the|of)" assert_count: 5 @@ -356,7 +356,7 @@ tests: pipeline: - Collection: books - Where: - - Function.regex_match: + - FunctionExpression.regex_match: - Field: title - Constant: ".*(?i)(the|of).*" assert_count: 5 @@ -377,7 +377,7 @@ tests: pipeline: - Collection: books - Where: - - Function.string_contains: + - FunctionExpression.string_contains: - Field: title - Constant: "Hitchhiker's" - Select: @@ -388,12 +388,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.to_lower: + - FunctionExpression.to_lower: - Field: title - "lower_title" assert_results: @@ -424,12 +424,12 @@ tests: pipeline: - Collection: books - Where: - 
- Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.to_upper: + - FunctionExpression.to_upper: - Field: title - "upper_title" assert_results: @@ -460,13 +460,13 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.trim: - - Function.string_concat: + - FunctionExpression.trim: + - FunctionExpression.string_concat: - Constant: " " - Field: title - Constant: " " @@ -504,12 +504,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Jane Austen" - Select: - AliasedExpression: - - Function.string_reverse: + - FunctionExpression.string_reverse: - Field: title - "reversed_title" assert_results: @@ -540,12 +540,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.substring: + - FunctionExpression.substring: - Field: title - Constant: 4 - Constant: 11 @@ -580,12 +580,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Fyodor Dostoevsky" - Select: - AliasedExpression: - - Function.substring: + - FunctionExpression.substring: - Field: title - Constant: 10 - "substring_title" @@ -618,12 +618,12 @@ tests: pipeline: - Collection: books - Where: - - Function.equal: + - FunctionExpression.equal: - Field: author - Constant: "Douglas Adams" - Select: - AliasedExpression: - - Function.join: + - FunctionExpression.join: - Field: tags - Constant: ", " - "joined_tags" diff --git a/tests/system/pipeline_e2e/vector.yaml b/tests/system/pipeline_e2e/vector.yaml index 85d265c2d..31df276b2 100644 --- a/tests/system/pipeline_e2e/vector.yaml +++ b/tests/system/pipeline_e2e/vector.yaml @@ -4,7 +4,7 @@ tests: - Collection: vectors - Select: - AliasedExpression: - - Function.vector_length: + - FunctionExpression.vector_length: - Field: embedding - "embedding_length" - Sort: @@ -117,12 +117,12 @@ tests: pipeline: - Collection: vectors - Where: - - Function.equal: + - FunctionExpression.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - AliasedExpression: - - Function.dot_product: + - FunctionExpression.dot_product: - Field: embedding - Vector: [1.0, 1.0, 1.0] - "dot_product_result" @@ -132,12 +132,12 @@ tests: pipeline: - Collection: vectors - Where: - - Function.equal: + - FunctionExpression.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - AliasedExpression: - - Function.euclidean_distance: + - FunctionExpression.euclidean_distance: - Field: embedding - Vector: [1.0, 2.0, 3.0] - "euclidean_distance_result" @@ -147,12 +147,12 @@ tests: pipeline: - Collection: vectors - Where: - - Function.equal: + - FunctionExpression.equal: - Field: embedding - Vector: [1.0, 2.0, 3.0] - Select: - AliasedExpression: - - Function.cosine_distance: + - FunctionExpression.cosine_distance: - Field: embedding - Vector: [1.0, 2.0, 3.0] - "cosine_distance_result" diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 84eb6cfe9..e2c6dcd0f 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -566,12 +566,12 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): 
BooleanExpression._from_query_filter_pb(document_pb.Value(), mock_client) -class TestFunction: +class TestFunctionExpression: def test_equals(self): - assert expr.Function.sqrt("1") == expr.Function.sqrt("1") - assert expr.Function.sqrt("1") != expr.Function.sqrt("2") - assert expr.Function.sqrt("1") != expr.Function.sum("1") - assert expr.Function.sqrt("1") != object() + assert expr.FunctionExpression.sqrt("1") == expr.FunctionExpression.sqrt("1") + assert expr.FunctionExpression.sqrt("1") != expr.FunctionExpression.sqrt("2") + assert expr.FunctionExpression.sqrt("1") != expr.FunctionExpression.sum("1") + assert expr.FunctionExpression.sqrt("1") != object() class TestArray: From 88e6dfe1e7894faa8a54a8096c63e0e7ff126a56 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 14 Jan 2026 09:20:35 -0800 Subject: [PATCH 21/27] chore: revert generated files --- .../services/firestore/async_client.py | 3 - .../firestore_v1/services/firestore/client.py | 3 - .../firestore/transports/rest_base.py | 68 +++++++++++-------- 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/google/cloud/firestore_v1/services/firestore/async_client.py b/google/cloud/firestore_v1/services/firestore/async_client.py index 96421f879..3557eb94c 100644 --- a/google/cloud/firestore_v1/services/firestore/async_client.py +++ b/google/cloud/firestore_v1/services/firestore/async_client.py @@ -238,9 +238,6 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. - NOTE: "rest" transport functionality is currently in a - beta state (preview). We welcome your feedback via an - issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. diff --git a/google/cloud/firestore_v1/services/firestore/client.py b/google/cloud/firestore_v1/services/firestore/client.py index e362896af..ac86aaa9e 100644 --- a/google/cloud/firestore_v1/services/firestore/client.py +++ b/google/cloud/firestore_v1/services/firestore/client.py @@ -571,9 +571,6 @@ def __init__( If a Callable is given, it will be called with the same set of initialization arguments as used in the FirestoreTransport constructor. If set to None, a transport is chosen automatically. - NOTE: "rest" transport functionality is currently in a - beta state (preview). We welcome your feedback via an - issue in this library's source repository. client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the client. 
diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py index 66cffc43c..239eb7dee 100644 --- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py +++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py @@ -130,7 +130,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -139,7 +139,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -148,6 +148,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBatchWrite: @@ -186,7 +187,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -195,7 +196,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -204,6 +205,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseBeginTransaction: @@ -242,7 +244,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -251,7 +253,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -260,6 +262,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCommit: @@ -298,7 +301,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -307,7 +310,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -316,6 +319,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseCreateDocument: @@ -354,7 +358,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -363,7 +367,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -372,6 +376,7 @@ 
def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseDeleteDocument: @@ -409,7 +414,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -418,6 +423,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseExecutePipeline: @@ -456,7 +462,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -465,7 +471,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -569,7 +575,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -578,6 +584,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListCollectionIds: @@ -621,7 +628,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -630,7 +637,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -639,6 +646,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListDocuments: @@ -680,7 +688,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -689,6 +697,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseListen: @@ -736,7 +745,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -745,7 +754,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -754,6 +763,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRollback: @@ -792,7 +802,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -801,7 +811,7 @@ def 
_get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -810,6 +820,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunAggregationQuery: @@ -853,7 +864,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -862,7 +873,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -871,6 +882,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseRunQuery: @@ -914,7 +926,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -923,7 +935,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -932,6 +944,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseUpdateDocument: @@ -970,7 +983,7 @@ def _get_request_body_json(transcoded_request): # Jsonify the request body body = json_format.MessageToJson( - transcoded_request["body"], use_integers_for_enums=False + transcoded_request["body"], use_integers_for_enums=True ) return body @@ -979,7 +992,7 @@ def _get_query_params_json(transcoded_request): query_params = json.loads( json_format.MessageToJson( transcoded_request["query_params"], - use_integers_for_enums=False, + use_integers_for_enums=True, ) ) query_params.update( @@ -988,6 +1001,7 @@ def _get_query_params_json(transcoded_request): ) ) + query_params["$alt"] = "json;enum-encoding=int" return query_params class _BaseWrite: From 987c923800a567c39f6d46504251adf408f92023 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 14 Jan 2026 10:02:34 -0800 Subject: [PATCH 22/27] chore: fix typing for 3.8 --- google/cloud/firestore_v1/pipeline_result.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py index 704811b94..0496d0bfc 100644 --- a/google/cloud/firestore_v1/pipeline_result.py +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -24,6 +24,7 @@ AsyncIterator, Iterable, Iterator, + List, Generic, MutableMapping, Type, @@ -258,12 +259,12 @@ def _process_response(self, response: ExecutePipelineResponse) -> Iterable[T]: ) -class PipelineSnapshot(_PipelineResultContainer[T], list[T]): +class PipelineSnapshot(_PipelineResultContainer[T], List[T]): """ A list type that holds the result of a pipeline.execute() operation, along with related metadata """ - def __init__(self, results_list: list[T], source: _PipelineResultContainer[T]): + def __init__(self, results_list: List[T], source: _PipelineResultContainer[T]): 
self.__dict__.update(source.__dict__.copy()) list.__init__(self, results_list) # snapshots are always complete From 27af228d7c3032dcc5530824068cec42d38e5119 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 14 Jan 2026 10:07:58 -0800 Subject: [PATCH 23/27] chore(tests): updated generated tests --- .../unit/gapic/firestore_v1/test_firestore.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/unit/gapic/firestore_v1/test_firestore.py b/tests/unit/gapic/firestore_v1/test_firestore.py index af45e4326..e3821e772 100644 --- a/tests/unit/gapic/firestore_v1/test_firestore.py +++ b/tests/unit/gapic/firestore_v1/test_firestore.py @@ -6385,7 +6385,7 @@ def test_get_document_rest_required_fields(request_type=firestore.GetDocumentReq response = client.get_document(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6526,7 +6526,7 @@ def test_list_documents_rest_required_fields( response = client.list_documents(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6727,7 +6727,7 @@ def test_update_document_rest_required_fields( response = client.update_document(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -6919,7 +6919,7 @@ def test_delete_document_rest_required_fields( response = client.delete_document(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7105,7 +7105,7 @@ def test_batch_get_documents_rest_required_fields( iter_content.return_value = iter(json_return_value) response = client.batch_get_documents(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7228,7 +7228,7 @@ def test_begin_transaction_rest_required_fields( response = client.begin_transaction(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7405,7 +7405,7 @@ def test_commit_rest_required_fields(request_type=firestore.CommitRequest): response = client.commit(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7585,7 +7585,7 @@ def test_rollback_rest_required_fields(request_type=firestore.RollbackRequest): response = client.rollback(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -7773,7 +7773,7 @@ def test_run_query_rest_required_fields(request_type=firestore.RunQueryRequest): iter_content.return_value = iter(json_return_value) response = client.run_query(request) - expected_params = [] + expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params @@ -8028,7 +8028,7 @@ def test_run_aggregation_query_rest_required_fields( iter_content.return_value = iter(json_return_value) response = 
 
-            expected_params = []
+            expected_params = [("$alt", "json;enum-encoding=int")]
             actual_params = req.call_args.kwargs["params"]
             assert expected_params == actual_params
 
@@ -8149,7 +8149,7 @@ def test_partition_query_rest_required_fields(
 
             response = client.partition_query(request)
 
-            expected_params = []
+            expected_params = [("$alt", "json;enum-encoding=int")]
             actual_params = req.call_args.kwargs["params"]
             assert expected_params == actual_params
 
@@ -8357,7 +8357,7 @@ def test_list_collection_ids_rest_required_fields(
 
             response = client.list_collection_ids(request)
 
-            expected_params = []
+            expected_params = [("$alt", "json;enum-encoding=int")]
             actual_params = req.call_args.kwargs["params"]
             assert expected_params == actual_params
 
@@ -8597,7 +8597,7 @@ def test_batch_write_rest_required_fields(request_type=firestore.BatchWriteReque
 
             response = client.batch_write(request)
 
-            expected_params = []
+            expected_params = [("$alt", "json;enum-encoding=int")]
             actual_params = req.call_args.kwargs["params"]
             assert expected_params == actual_params
 
@@ -8729,7 +8729,7 @@ def test_create_document_rest_required_fields(
 
             response = client.create_document(request)
 
-            expected_params = []
+            expected_params = [("$alt", "json;enum-encoding=int")]
             actual_params = req.call_args.kwargs["params"]
             assert expected_params == actual_params
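Note: both halves of this change follow from the REST transport now requesting integer
enum encoding. Bodies and query params are serialized with `use_integers_for_enums=True`,
and the `$alt=json;enum-encoding=int` system parameter the transport appends is what the
generated tests now expect on every call. The effect of the serialization flag can be
seen with any protobuf message that carries an enum field; `google.protobuf.type_pb2.Field`
is used below purely because it ships with protobuf:

.. code-block:: python

    from google.protobuf import json_format, type_pb2

    # A message with an enum field set (Kind.TYPE_STRING has numeric value 9).
    field = type_pb2.Field(kind=type_pb2.Field.Kind.TYPE_STRING)

    # Default JSON encoding spells the enum out by name.
    print(json_format.MessageToJson(field))
    # {
    #   "kind": "TYPE_STRING"
    # }

    # With use_integers_for_enums=True the numeric value is emitted instead.
    print(json_format.MessageToJson(field, use_integers_for_enums=True))
    # {
    #   "kind": 9
    # }

Because responses can now carry bare integers where enum names used to appear, the client
advertises the choice via the `$alt` query parameter so both sides agree on the wire
format.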
From f112a85b45f8319c26355bb0b12f04e86d0297b4 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 14 Jan 2026 11:29:06 -0800
Subject: [PATCH 24/27] chore(tests): improve client mocking in unit tests

---
 tests/unit/v1/test_async_pipeline.py  | 15 +++++++--------
 tests/unit/v1/test_pipeline.py        | 11 +++++------
 tests/unit/v1/test_pipeline_source.py |  6 ++++--
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py
index 5a7fb360c..18805b7b2 100644
--- a/tests/unit/v1/test_async_pipeline.py
+++ b/tests/unit/v1/test_async_pipeline.py
@@ -18,6 +18,8 @@
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.pipeline_expressions import Field
 
+from tests.unit.v1._test_helpers import make_async_client
+
 
 def _make_async_pipeline(*args, client=mock.Mock()):
     from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
@@ -189,11 +191,10 @@ async def test_async_pipeline_stream_populated():
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
-    from google.cloud.firestore_v1.document import DocumentReference
+    from google.cloud.firestore_v1.async_document import AsyncDocumentReference
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
-    real_client = Client()
+    real_client = make_async_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
@@ -228,7 +229,7 @@ async def test_async_pipeline_stream_populated():
 
     response = results[0]
     assert isinstance(response, PipelineResult)
-    assert isinstance(response.ref, DocumentReference)
+    assert isinstance(response.ref, AsyncDocumentReference)
     assert response.ref.path == "test/my_doc"
     assert response.id == "my_doc"
     assert response.create_time.seconds == 1
@@ -246,10 +247,9 @@ async def test_async_pipeline_stream_multiple():
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
-    real_client = Client()
+    real_client = make_async_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
@@ -358,9 +358,8 @@ async def test_async_pipeline_stream_stream_equivalence():
     from google.cloud.firestore_v1.types import Document
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
 
-    real_client = Client()
+    real_client = make_async_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py
index fc8e90a04..10509cafb 100644
--- a/tests/unit/v1/test_pipeline.py
+++ b/tests/unit/v1/test_pipeline.py
@@ -18,6 +18,8 @@
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.pipeline_expressions import Field
 
+from tests.unit.v1._test_helpers import make_client
+
 
 def _make_pipeline(*args, client=mock.Mock()):
     from google.cloud.firestore_v1.pipeline import Pipeline
@@ -190,11 +192,10 @@ def test_pipeline_stream_populated():
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.document import DocumentReference
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
-    real_client = Client()
+    real_client = make_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
@@ -244,10 +245,9 @@ def test_pipeline_stream_multiple():
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import ExecutePipelineRequest
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.pipeline_result import PipelineResult
 
-    real_client = Client()
+    real_client = make_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
@@ -348,9 +348,8 @@ def test_pipeline_execute_stream_equivalence():
     from google.cloud.firestore_v1.types import Document
     from google.cloud.firestore_v1.types import ExecutePipelineResponse
     from google.cloud.firestore_v1.types import Value
-    from google.cloud.firestore_v1.client import Client
 
-    real_client = Client()
+    real_client = make_client()
     client = mock.Mock()
     client.project = "A"
     client._database = "B"
diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py
index 69754a941..871522035 100644
--- a/tests/unit/v1/test_pipeline_source.py
+++ b/tests/unit/v1/test_pipeline_source.py
@@ -23,12 +23,14 @@
 from google.cloud.firestore_v1.query import Query
 from google.cloud.firestore_v1.async_query import AsyncQuery
 
+from tests.unit.v1._test_helpers import make_async_client
+from tests.unit.v1._test_helpers import make_client
 
 class TestPipelineSource:
     _expected_pipeline_type = Pipeline
 
     def _make_client(self):
-        return Client()
+        return make_client()
 
     def _make_query(self):
         return Query(mock.Mock())
@@ -120,7 +122,7 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource):
     _expected_pipeline_type = AsyncPipeline
 
     def _make_client(self):
-        return AsyncClient()
+        return make_async_client()
 
     def _make_query(self):
         return AsyncQuery(mock.Mock())
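Note: the shared helpers keep these unit tests hermetic. Instantiating `Client()` with no
arguments makes the library resolve Application Default Credentials (and a project),
which fails in CI and can touch the network. The `_test_helpers` module itself is not
shown in this series; a minimal sketch of what `make_client` / `make_async_client` need
to do, with the project id chosen arbitrarily:

.. code-block:: python

    from google.auth.credentials import AnonymousCredentials

    from google.cloud.firestore_v1.async_client import AsyncClient
    from google.cloud.firestore_v1.client import Client


    def make_client(project: str = "my-project") -> Client:
        # Real construction, fake identity: no ADC lookup, no network calls.
        return Client(project=project, credentials=AnonymousCredentials())


    def make_async_client(project: str = "my-project") -> AsyncClient:
        return AsyncClient(project=project, credentials=AnonymousCredentials())

The tests above still pair the helper-built client with a plain `mock.Mock()` that
carries per-test attributes, so only the code paths that genuinely need a concrete
client touch a real instance.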
From 4c33ae82935d67a221d0ff70a72f0ed4ccf8a06a Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 14 Jan 2026 11:38:03 -0800
Subject: [PATCH 25/27] chore: fix mypy check

---
 .../firestore/transports/rest_base.py | 57 -------------------
 tests/unit/v1/test_pipeline_source.py |  3 +-
 2 files changed, 1 insertion(+), 59 deletions(-)

diff --git a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py
index 239eb7dee..80ce35e49 100644
--- a/google/cloud/firestore_v1/services/firestore/transports/rest_base.py
+++ b/google/cloud/firestore_v1/services/firestore/transports/rest_base.py
@@ -483,63 +483,6 @@ def _get_query_params_json(transcoded_request):
             query_params["$alt"] = "json;enum-encoding=int"
             return query_params
 
-    class _BaseExecutePipeline:
-        def __hash__(self):  # pragma: NO COVER
-            return NotImplementedError("__hash__ must be implemented.")
-
-        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {}
-
-        @classmethod
-        def _get_unset_required_fields(cls, message_dict):
-            return {
-                k: v
-                for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items()
-                if k not in message_dict
-            }
-
-        @staticmethod
-        def _get_http_options():
-            http_options: List[Dict[str, str]] = [
-                {
-                    "method": "post",
-                    "uri": "/v1/{database=projects/*/databases/*}/documents:executePipeline",
-                    "body": "*",
-                },
-            ]
-            return http_options
-
-        @staticmethod
-        def _get_transcoded_request(http_options, request):
-            pb_request = firestore.ExecutePipelineRequest.pb(request)
-            transcoded_request = path_template.transcode(http_options, pb_request)
-            return transcoded_request
-
-        @staticmethod
-        def _get_request_body_json(transcoded_request):
-            # Jsonify the request body
-
-            body = json_format.MessageToJson(
-                transcoded_request["body"], use_integers_for_enums=True
-            )
-            return body
-
-        @staticmethod
-        def _get_query_params_json(transcoded_request):
-            query_params = json.loads(
-                json_format.MessageToJson(
-                    transcoded_request["query_params"],
-                    use_integers_for_enums=True,
-                )
-            )
-            query_params.update(
-                _BaseFirestoreRestTransport._BaseExecutePipeline._get_unset_required_fields(
-                    query_params
-                )
-            )
-
-            query_params["$alt"] = "json;enum-encoding=int"
-            return query_params
-
     class _BaseGetDocument:
         def __hash__(self):  # pragma: NO COVER
             return NotImplementedError("__hash__ must be implemented.")
diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py
index 871522035..d6665d4bc 100644
--- a/tests/unit/v1/test_pipeline_source.py
+++ b/tests/unit/v1/test_pipeline_source.py
@@ -16,8 +16,6 @@
 from google.cloud.firestore_v1.pipeline_source import PipelineSource
 from google.cloud.firestore_v1.pipeline import Pipeline
 from google.cloud.firestore_v1.async_pipeline import AsyncPipeline
-from google.cloud.firestore_v1.client import Client
-from google.cloud.firestore_v1.async_client import AsyncClient
 from google.cloud.firestore_v1 import pipeline_stages as stages
 from google.cloud.firestore_v1.base_document import BaseDocumentReference
 from google.cloud.firestore_v1.query import Query
@@ -26,6 +24,7 @@
 from tests.unit.v1._test_helpers import make_async_client
 from tests.unit.v1._test_helpers import make_client
 
+
 class TestPipelineSource:
     _expected_pipeline_type = Pipeline
From 61ae07a8e02fc1837a8ebe8c4efb1beb36003a72 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 14 Jan 2026 11:51:15 -0800
Subject: [PATCH 26/27] chore(tests): disable pipeline tests in emulator

---
 tests/system/test_pipeline_acceptance.py | 6 ++++++
 tests/system/test_system.py              | 3 +++
 tests/system/test_system_async.py        | 3 +++
 tests/unit/v1/test_pipeline_result.py    | 2 +-
 4 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py
index 4634037ab..2a83f4eaf 100644
--- a/tests/system/test_pipeline_acceptance.py
+++ b/tests/system/test_pipeline_acceptance.py
@@ -34,9 +34,15 @@
 from google.cloud.firestore import Client, AsyncClient
 
 from test__helpers import FIRESTORE_ENTERPRISE_DB
+from test__helpers import FIRESTORE_EMULATOR
 
 FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
 
+pytestmark = pytest.mark.skipif(
+    condition=FIRESTORE_EMULATOR,
+    reason="Pipeline tests are currently not supported by emulator",
+)
+
 test_dir_name = os.path.dirname(__file__)
 
 id_format = (
diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 328b29098..2f6e877d0 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -94,6 +94,9 @@ def verify_pipeline(query):
     """
     from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
 
+    if FIRESTORE_EMULATOR:
+        pytest.skip("skip pipeline verification on emulator")
+
     def _clean_results(results):
         if isinstance(results, dict):
             return {k: _clean_results(v) for k, v in results.items()}
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
index 1aaa79591..76d1b5538 100644
--- a/tests/system/test_system_async.py
+++ b/tests/system/test_system_async.py
@@ -174,6 +174,9 @@ async def verify_pipeline(query):
     """
     from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
 
+    if FIRESTORE_EMULATOR:
+        pytest.skip("skip pipeline verification on emulator")
+
     def _clean_results(results):
         if isinstance(results, dict):
             return {k: _clean_results(v) for k, v in results.items()}
diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py
index eca622801..3650074bc 100644
--- a/tests/unit/v1/test_pipeline_result.py
+++ b/tests/unit/v1/test_pipeline_result.py
@@ -484,7 +484,7 @@ async def test_double_iterate(self):
 
         async def async_gen(items):
             for item in items:
-                yield item
+                yield item  # pragma: NO COVER
 
         # mock the api call to avoid real network requests
         instance._client._firestore_api.execute_pipeline = mock.AsyncMock(
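Note: this patch uses all three pytest idioms for keeping a suite green against a backend
(the emulator) that lacks pipeline support. A condensed sketch, assuming
`FIRESTORE_EMULATOR` is derived from the `FIRESTORE_EMULATOR_HOST` environment variable
as in the suite's `test__helpers` modules:

.. code-block:: python

    import os

    import pytest

    # Assumption: mirrors how the system-test helpers detect the emulator.
    FIRESTORE_EMULATOR = os.environ.get("FIRESTORE_EMULATOR_HOST") is not None

    # 1. Module level: every test in the file is skipped.
    pytestmark = pytest.mark.skipif(
        condition=FIRESTORE_EMULATOR,
        reason="Pipeline tests are currently not supported by emulator",
    )


    # 2. Per test: only the decorated test is skipped.
    @pytest.mark.skipif(
        FIRESTORE_EMULATOR, reason="Pipeline query not supported in emulator."
    )
    def test_single_pipeline_feature():
        ...


    # 3. At runtime: pytest.skip() raises, aborting the calling test with a
    # skip the moment pipeline verification is reached, so assertions that
    # ran earlier in that test still get a chance to fail first.
    def verify_pipeline(query):
        if FIRESTORE_EMULATOR:
            pytest.skip("skip pipeline verification on emulator")

The module-level marker suits `test_pipeline_acceptance.py`, where nothing can run on the
emulator; the runtime skip suits a helper like `verify_pipeline`, which many otherwise
emulator-safe tests call.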
From 71baf580c23259ab2be46dc0cd8897df607706fb Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 14 Jan 2026 11:58:05 -0800
Subject: [PATCH 27/27] chore(tests): disable pipeline read time test in
 emulator

---
 tests/system/test_system.py       | 3 +++
 tests/system/test_system_async.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/tests/system/test_system.py b/tests/system/test_system.py
index 2f6e877d0..0c86c69a3 100644
--- a/tests/system/test_system.py
+++ b/tests/system/test_system.py
@@ -1800,6 +1800,9 @@ def test_query_stream_w_read_time(query_docs, cleanup, database):
     assert new_values[new_ref.id] == new_data
 
 
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Pipeline query not supported in emulator."
+)
 @pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
 def test_pipeline_w_read_time(query_docs, cleanup, database):
     collection, stored, allowed_vals = query_docs
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
index 76d1b5538..1442e7932 100644
--- a/tests/system/test_system_async.py
+++ b/tests/system/test_system_async.py
@@ -1685,6 +1685,9 @@ async def test_pipeline_explain_options_using_additional_options(
     assert "Execution:" in text_stats
 
 
+@pytest.mark.skipif(
+    FIRESTORE_EMULATOR, reason="Pipeline query not supported in emulator."
+)
 @pytest.mark.parametrize("database", [FIRESTORE_ENTERPRISE_DB], indirect=True)
 async def test_pipeline_w_read_time(query_docs, cleanup, database):
     collection, stored, allowed_vals = query_docs
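Note: the `indirect=True` in the decorators above routes the parametrize value through
the `database` fixture rather than handing the raw id to the test body, so the fixture
can construct or select the named database before the test runs. A self-contained sketch
of the mechanism (the fixture here is a stand-in for the suite's real `database`
fixture, and the id is arbitrary):

.. code-block:: python

    import pytest


    @pytest.fixture
    def database(request):
        # With indirect=True the parametrized value arrives as request.param,
        # letting the fixture turn an id into a configured resource.
        return {"database_id": request.param}


    @pytest.mark.parametrize("database", ["enterprise-db"], indirect=True)
    def test_runs_against_selected_database(database):
        assert database["database_id"] == "enterprise-db"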